diff --git a/.black.toml b/.black.toml index b9123f45..ab65d233 100644 --- a/.black.toml +++ b/.black.toml @@ -16,6 +16,8 @@ exclude = ''' | buck-out | build | dist + | alembic + | gen )/ ) ''' diff --git a/.circleci/config.yml b/.circleci/config.yml index f9243eef..09c86645 100644 --- a/.circleci/config.yml +++ b/.circleci/config.yml @@ -1,6 +1,7 @@ version: 2.1 orbs: python: circleci/python@2.1.1 + aws-cli: circleci/aws-cli@3.1.5 workflows: ci: @@ -10,11 +11,16 @@ workflows: - integration_tests - build_image - build_docs + - deploy_docs: + filters: + branches: + only: + - main jobs: run_unit_tests_python_client: docker: - - image: python:3.8-bookworm + - image: python:3.10-bookworm resource_class: small parallelism: 1 steps: @@ -28,7 +34,7 @@ jobs: - run_unit_tests_python_client run_unit_tests_server: docker: - - image: python:3.8-bookworm + - image: python:3.10-bookworm environment: ML_INFRA_DATABASE_URL: postgresql://postgres@localhost/circle_test - image: circleci/postgres:12.9-postgis-ram @@ -48,7 +54,7 @@ jobs: - run_unit_tests_server build_docs: docker: - - image: python:3.8-bookworm + - image: python:3.10-bookworm resource_class: small parallelism: 1 steps: @@ -62,6 +68,25 @@ jobs: name: Build Docs command: | mkdocs build --strict + deploy_docs: + docker: + - image: python:3.10-bookworm + resource_class: small + parallelism: 1 + steps: + - add_ssh_keys: # gives write access to CircleCI worker + fingerprints: + - "76:0c:1b:9e:e3:6a:c3:5c:6f:24:91:ef:7c:54:d2:7a" + - checkout # checkout source code to working directory + - environment_setup + - install_client + - python/install-packages: + pkg-manager: pip + pip-dependency-file: requirements-docs.txt + - run: + name: Deploy Docs + command: | + mkdocs gh-deploy build_image: executor: ubuntu-large steps: @@ -69,34 +94,105 @@ jobs: - run: name: Build Docker Image command: | - docker build . -f server/Dockerfile -t llm-engine:$CIRCLE_SHA1 + docker build . -f model-engine/Dockerfile -t model-engine:$CIRCLE_SHA1 integration_tests: executor: ubuntu-large steps: - checkout + - aws-cli/setup: + role-arn: ${CIRCLECI_ROLE_ARN} + aws-region: AWS_REGION + - run: + name: Build Docker Image + command: | + docker build . -f model-engine/Dockerfile -t model-engine:$CIRCLE_SHA1 - run: name: Install minikube command: | cd $HOME curl -LO https://storage.googleapis.com/minikube/releases/latest/minikube_latest_amd64.deb sudo dpkg -i minikube_latest_amd64.deb - minikube start --vm-driver=docker --kubernetes-version=v1.23.0 --memory=14336 --cpus=8 + minikube start --vm-driver=docker --kubernetes-version=v1.23.0 --memory=49152 --cpus=14 - run: - name: Install helm + name: Install kubectl, helm command: | - cd $HOME + cd $HOME/bin curl https://raw.githubusercontent.com/helm/helm/main/scripts/get-helm-3 | bash + curl -LO "https://dl.k8s.io/release/v1.23.0/bin/linux/amd64/kubectl" + chmod +x kubectl + - run: + name: Install helm chart dependencies (Redis, Postgres, Istio) + command: | + sudo apt-get update && sudo apt-get install -y expect + pushd $HOME/project/.circleci/resources + kubectl create namespace model-engine + kubectl apply -f redis-k8s.yaml + kubectl apply -f postgres-k8s.yaml + kubectl create secret generic model-engine-postgres-credentials --from-literal=database_url=postgresql://postgres:circle_test@postgres.default:5432/circle_test + kubectl create secret generic model-engine-postgres-credentials --from-literal=database_url=postgresql://postgres:circle_test@postgres.default:5432/circle_test -n model-engine + export ISTIO_VERSION=1.15.0 + popd + curl -L https://istio.io/downloadIstio | TARGET_ARCH=x86_64 sh - + install istio-${ISTIO_VERSION}/bin/istioctl $HOME/bin + $HOME/bin/istioctl install --set profile=demo -y + kubectl create configmap default-config --from-literal=config="$(cat $HOME/project/.circleci/resources/.minikube-config-map | envsubst)" + kubectl create configmap default-config --namespace model-engine --from-literal=config="$(cat $HOME/project/.circleci/resources/.minikube-config-map | envsubst)" + cat $HOME/project/.circleci/resources/.minikube-registry-creds | envsubst | expect + minikube addons enable registry-creds + - run: + name: Pre-load model-engine image to minikube + command: | + minikube --logtostderr -v 1 image load model-engine:$CIRCLE_SHA1 + - run: + name: Pre-load integration test images to minikube + command: | + docker build -f model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile \ + --build-arg BASE_IMAGE=python:3.8-slim \ + --build-arg REQUIREMENTS_FILE="$CIRCLE_SHA1-base-requirements.txt" \ + -t temp:1.11.0-cuda11.3-cudnn8-runtime-$CIRCLE_SHA1 . + + touch $CIRCLE_SHA1-requirements.txt + echo -e "cloudpickle==2.1.0\npyyaml==6.0" > $CIRCLE_SHA1-requirements.txt + + DOCKER_BUILDKIT=1 docker build -f model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile \ + --build-arg BASE_IMAGE=temp:1.11.0-cuda11.3-cudnn8-runtime-$CIRCLE_SHA1 \ + --build-arg REQUIREMENTS_FILE="$CIRCLE_SHA1-requirements.txt" \ + -t $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.11.0-cuda11.3-cudnn8-runtime-$CIRCLE_SHA1-b8c25b . + rm $CIRCLE_SHA1-requirements.txt + + minikube --logtostderr -v 1 image load $CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com/hosted-model-inference/async-pytorch:1.11.0-cuda11.3-cudnn8-runtime-$CIRCLE_SHA1-b8c25b - run: name: Install helm chart command: | - cd $HOME/project/charts - helm install llm-engine llm-engine --values llm-engine/values_sample.yaml + pushd $HOME/project/charts + cat model-engine/values_circleci.yaml | envsubst > model-engine/values_circleci_subst.yaml + helm install model-engine model-engine --values model-engine/values_circleci_subst.yaml --set tag=$CIRCLE_SHA1 --atomic --debug + - run: + name: Change python version to 3.10.14 + command: | + pyenv install 3.10.14 + pyenv global 3.10.14 + - run: + name: Install integration test dependencies + command: | + export DEBIAN_FRONTEND=noninteractive + sudo apt-get update && sudo apt-get install -y libcurl4-openssl-dev libssl-dev python3-dev + pip install -r model-engine/requirements.txt + - install_client + - install_server + - run: + name: Run integration tests + command: | + pushd $HOME/project + kubectl port-forward svc/model-engine 5001:80 & + export GIT_TAG=$CIRCLE_SHA1 + pytest integration_tests executors: ubuntu-large: machine: - image: "ubuntu-2004:202201-02" - resource_class: xlarge + image: default + resource_class: 2xlarge commands: environment_setup: @@ -112,29 +208,30 @@ commands: install_server: description: Installs LLM Engine server steps: - - python/install-packages: - pkg-manager: pip - app-dir: server - - python/install-packages: - pkg-manager: pip - app-dir: server - pip-dependency-file: requirements-test.txt - - python/install-packages: - pkg-manager: pip - app-dir: server - pip-dependency-file: requirements_override.txt - - run: - name: Install Server - command: | - pushd server - pip install -e . - popd + - python/install-packages: + pkg-manager: pip + app-dir: model-engine + - python/install-packages: + pkg-manager: pip + app-dir: model-engine + pip-dependency-file: requirements-test.txt + - python/install-packages: + pkg-manager: pip + app-dir: model-engine + pip-dependency-file: requirements_override.txt + - run: + name: Install Server + command: | + pushd model-engine + pip install -e . + popd install_client: description: Install LLM Engine client steps: - run: name: Install LLM Engine client command: | + pip install --upgrade pip pip install -e $HOME/project/clients/python run_unit_tests_python_client: description: Unit tests of the python client @@ -159,16 +256,17 @@ commands: - run: name: Ruff Lint Check command: | - ruff . + ruff check . - run: name: Type Check command: | - pushd server + pushd model-engine mypy . --install-types --non-interactive popd - run: name: Unit Tests command: | - pushd server - WORKSPACE=.. pytest + pushd model-engine + GIT_TAG=$(git rev-parse HEAD) WORKSPACE=.. pytest --cov --cov-report=xml + diff-cover coverage.xml --compare-branch=origin/main --fail-under=80 popd diff --git a/.circleci/resources/.minikube-config-map b/.circleci/resources/.minikube-config-map new file mode 100644 index 00000000..620e3ab1 --- /dev/null +++ b/.circleci/resources/.minikube-config-map @@ -0,0 +1,5 @@ +# Configmap for AWS credentials inside minikube. +[default] +aws_access_key_id = $AWS_ACCESS_KEY_ID +aws_secret_access_key = $AWS_SECRET_ACCESS_KEY +aws_session_token = $AWS_SESSION_TOKEN \ No newline at end of file diff --git a/.circleci/resources/.minikube-registry-creds b/.circleci/resources/.minikube-registry-creds new file mode 100644 index 00000000..37f4b1fa --- /dev/null +++ b/.circleci/resources/.minikube-registry-creds @@ -0,0 +1,15 @@ +# Script to send the registry-creds addon configuration to minikube +# Source: https://github.com/kubernetes/minikube/issues/8283 +# See expect syntax here: https://manpages.ubuntu.com/manpages/trusty/man1/expect.1.html +spawn minikube addons configure registry-creds +expect "Do you want to enable AWS Elastic Container Registry?" { send "y\r" } +expect "Enter AWS Access Key ID:" { send "$AWS_ACCESS_KEY_ID\r" } +expect "Enter AWS Secret Access Key:" { send "$AWS_SECRET_ACCESS_KEY\r" } +expect "Enter AWS Session Token:" { send "$AWS_SESSION_TOKEN\r" } +expect "Enter AWS Region:" { send "us-west-2\r" } +expect "Enter 12 digit AWS Account ID (Comma separated list):" { send "$CIRCLECI_AWS_ACCOUNT_ID\r" } +expect "Enter ARN of AWS role to assume:" { send "\r" } +expect "Do you want to enable Google Container Registry?" { send "n\r" } +expect "Do you want to enable Docker Registry?" { send "n\r" } +expect "Do you want to enable Azure Container Registry?" { send "n\r" } +expect eof diff --git a/.circleci/resources/postgres-k8s.yaml b/.circleci/resources/postgres-k8s.yaml new file mode 100644 index 00000000..13d33fe9 --- /dev/null +++ b/.circleci/resources/postgres-k8s.yaml @@ -0,0 +1,50 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: postgres + labels: + app: postgres +spec: + replicas: 1 + selector: + matchLabels: + app: postgres + template: + metadata: + labels: + app: postgres + spec: + containers: + - name: main + image: "cimg/postgres:12.8-postgis" + imagePullPolicy: IfNotPresent + resources: + requests: + memory: 1Gi + cpu: 1 + ports: + - containerPort: 5432 + env: + - name: POSTGRES_USER + value: postgres + - name: POSTGRES_DB + value: circle_test + - name: POSTGRES_PASSWORD + value: circle_test + +--- + +kind: Service +apiVersion: v1 +metadata: + name: postgres + labels: + app: postgres +spec: + type: ClusterIP + selector: + app: postgres + ports: + - name: redis + port: 5432 + targetPort: 5432 diff --git a/.circleci/resources/redis-k8s.yaml b/.circleci/resources/redis-k8s.yaml new file mode 100644 index 00000000..1d3207fe --- /dev/null +++ b/.circleci/resources/redis-k8s.yaml @@ -0,0 +1,43 @@ +apiVersion: apps/v1 +kind: Deployment +metadata: + name: redis-message-broker-master + labels: + app: redis-message-broker-master +spec: + replicas: 1 + selector: + matchLabels: + app: redis-message-broker-master + template: + metadata: + labels: + app: redis-message-broker-master + spec: + containers: + - name: main + image: redis + imagePullPolicy: IfNotPresent + resources: + requests: + memory: 1Gi + cpu: 1 + ports: + - containerPort: 6379 + +--- + +kind: Service +apiVersion: v1 +metadata: + name: redis-message-broker-master + labels: + app: redis-message-broker-master +spec: + type: ClusterIP + selector: + app: redis-message-broker-master + ports: + - name: redis + port: 6379 + targetPort: 6379 diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md new file mode 100644 index 00000000..ac062d42 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -0,0 +1,41 @@ +--- +name: "\U0001F41B Bug report" +about: Submit a bug report to help us improve LLM Engine. +title: '' +labels: bug +assignees: '' + +--- + +**Describe the bug** +Thank you for taking the time to file a bug report! Before you do so, please take a look at existing open issues and make sure that your issue is not already documented. If it isn't, please provide us with a clear and concise description of what the bug is. + +**LLM Engine Version** +- LLM Engine Version: + +**System Version** +- Python Version: +- Operating System: + +**Timestamp and Request ID** +_If you ran into an internal error while using `llm-engine`, please provide the following. These fields are provided in the JSON Response when an internal error occurs._ +- `timestamp`: +- `request_id`: + +**Minimal Reproducible Example** +Steps to reproduce the behavior: +1. Install LLM Engine '....' +2. Make API call '....' +3. See error + +Please provide a code snippet that documents how your bug can be reproduced. +``` +import llmengine +... +``` + +**Expected behavior** +A clear and concise description of what you expected to happen. + +**Additional context** +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/custom.md b/.github/ISSUE_TEMPLATE/custom.md new file mode 100644 index 00000000..89130c0b --- /dev/null +++ b/.github/ISSUE_TEMPLATE/custom.md @@ -0,0 +1,10 @@ +--- +name: Custom issue template +about: If your issue doesn't fall into a bug template or feature request, please provide some information on it here. +title: '' +labels: '' +assignees: '' + +--- + + diff --git a/.github/ISSUE_TEMPLATE/feature_request.md b/.github/ISSUE_TEMPLATE/feature_request.md new file mode 100644 index 00000000..3043cd19 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/feature_request.md @@ -0,0 +1,46 @@ +--- +name: "\U0001F680 Feature request" +about: Suggest an idea for this project +title: '' +labels: enhancement +assignees: '' + +--- + +## Feature Request + +**What is the problem you're currently running into?** +A clear and concise description of what the problem is. Ex. I'm always frustrated when [...] + +**Why do you want this feature?** +A clear and concise description of why you want the feature. + +**Describe the solution you'd like** +A clear and concise description of what you want to happen. + +**Describe alternatives you've considered** +A clear and concise description of any alternative solutions or features you've considered. + +**Additional context** +Add any other context or screenshots about the feature request here. + +### Prioritization + +- **Does this feature block you from using the project?** + - [ ] Yes + - [ ] No + +- **How many users will benefit from this feature?** + - [ ] Just me + - [ ] Few people might benefit + - [ ] Many users will love it! + +- **Complexity** + - [ ] I believe it's a simple feature to implement + - [ ] It might require some effort to implement + - [ ] It's probably complex, and might take significant effort + +--- + +Thank you for your contribution to `llm-engine`. Please ensure you've given the feature considerable thought before submitting it. Once your feature request is accepted, and you're interested in building it, please mention it so that the maintainers can guide you! + diff --git a/.github/PULL_REQUEST_TEMPLATE.md b/.github/PULL_REQUEST_TEMPLATE.md new file mode 100644 index 00000000..a5eb802d --- /dev/null +++ b/.github/PULL_REQUEST_TEMPLATE.md @@ -0,0 +1,7 @@ +# Pull Request Summary + +_What is this PR changing? Why is this change being made? Any caveats you'd like to highlight? Link any relevant documents, links, or screenshots here if applicable._ + +## Test Plan and Usage Guide + +_How did you validate that your PR works correctly? How do you run or demo the code? Provide enough detail so a reviewer can reasonably reproduce the testing procedure. Paste example command line invocations if applicable._ diff --git a/.gitignore b/.gitignore index d5bec1e7..276b0676 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,4 @@ +tags *.cache *.pt *.pkl diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 7844d3df..36bf3e95 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -1,36 +1,37 @@ +fail_fast: false repos: - repo: https://github.com/psf/black # Make sure to update requirements-dev-extra.txt to match versions! - rev: 22.12.0 + rev: 24.8.0 hooks: - id: black name: "python:black" entry: black --config .black.toml - repo: https://github.com/astral-sh/ruff-pre-commit - rev: v0.0.278 + rev: v0.6.8 hooks: - id: ruff name: "python:ruff" - repo: https://github.com/pycqa/isort - rev: 5.12.0 + rev: 5.13.2 hooks: - id: isort name: "python:isort" - repo: https://github.com/jazzband/pip-tools - rev: 6.6.2 + rev: 7.4.1 hooks: - id: pip-compile - files: server/requirements\.(in|txt) + files: model-engine/requirements\.(in|txt) args: [ - server/requirements.in, + model-engine/requirements.in, --allow-unsafe, --no-emit-index-url, --no-emit-trusted-host, --index-url=https://pypi.org/simple, ] - repo: https://github.com/pre-commit/pre-commit-hooks - rev: v4.4.0 # https://github.com/pre-commit/pre-commit-hooks/releases + rev: v4.6.0 # https://github.com/pre-commit/pre-commit-hooks/releases hooks: - id: check-added-large-files args: @@ -49,3 +50,31 @@ repos: language: python - id: check-toml language: python + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.11.2' # Make sure this matches the version in requirements-dev.txt! + hooks: + - id: mypy + name: mypy-clients-python + files: clients/python/.* + entry: mypy --config-file clients/python/mypy.ini + language: system + - repo: https://github.com/pre-commit/mirrors-mypy + rev: 'v1.11.2' # Make sure this matches the version in requirements-dev.txt! + hooks: + - id: mypy + name: mypy-server + entry: mypy --config-file model-engine/mypy.ini + language: system + - repo: local + hooks: + - id: trufflehog + name: TruffleHog + description: Detect secrets in your data. + entry: bash -c 'docker run --rm -v "$(pwd)/..:/workdir" -i --rm trufflesecurity/trufflehog:latest git file:///workdir/llm-engine --since-commit HEAD --only-verified --fail' + language: system + stages: ["commit", "push"] + - repo: https://github.com/returntocorp/semgrep + rev: 'v1.89.0' + hooks: + - id: semgrep + args: [ '--config', 'p/python', '--error' ] diff --git a/.ruff.toml b/.ruff.toml index af1d91d6..69f83253 100644 --- a/.ruff.toml +++ b/.ruff.toml @@ -2,3 +2,4 @@ line-length = 100 ignore = ["E501"] +exclude = ["gen", "alembic"] diff --git a/LICENSE b/LICENSE index d803528b..b8106a2f 100644 --- a/LICENSE +++ b/LICENSE @@ -1,14 +1,201 @@ -Copyright [2023] [Scale AI] + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION - http://www.apache.org/licenses/LICENSE-2.0 + 1. Definitions. -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright 2023 Scale AI + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index 51feafdb..39fc7098 100644 --- a/README.md +++ b/README.md @@ -1,6 +1,10 @@ -# ⚡ LLM Engine ⚡ +# LLM Engine -**The open source engine for fine-tuning large language models**. +[![LICENSE](https://img.shields.io/github/license/scaleapi/llm-engine.svg)](https://github.com/scaleapi/llm-engine/blob/master/LICENSE) +[![Release Notes](https://img.shields.io/github/release/scaleapi/llm-engine)](https://github.com/scaleapi/llm-engine/releases) +[![CircleCI](https://circleci.com/gh/scaleapi/llm-engine.svg?style=shield)](https://circleci.com/gh/scaleapi/llm-engine) + +🚀 **The open source engine for fine-tuning and serving large language models**. 🚀 Scale's LLM Engine is the easiest way to customize and serve LLMs. In LLM Engine, models can be accessed via Scale's hosted version or by using the Helm charts in this repository to run model inference and fine-tuning in your own infrastructure. @@ -87,4 +91,4 @@ print(response.output.text) You should see a successful completion of your given prompt! _What's next?_ Visit the [LLM Engine documentation pages](https://scaleapi.github.io/llm-engine/) for more on -the `Completion` and `FineTune` APIs and how to use them. +the `Completion` and `FineTune` APIs and how to use them. Check out this [blog post](https://scale.com/blog/fine-tune-llama-2) for an end-to-end example. diff --git a/charts/llm-engine/README.md b/charts/llm-engine/README.md deleted file mode 100644 index 9281c374..00000000 --- a/charts/llm-engine/README.md +++ /dev/null @@ -1,3 +0,0 @@ -# LLM Engine Helm Chart - -This chart contains k8s templates for deploying LLM Engine to a k8s cluster. diff --git a/charts/llm-engine/templates/_helpers.tpl b/charts/llm-engine/templates/_helpers.tpl deleted file mode 100644 index 04c8168f..00000000 --- a/charts/llm-engine/templates/_helpers.tpl +++ /dev/null @@ -1,367 +0,0 @@ -{{/* -Expand the name of the chart. -*/}} -{{- define "llmEngine.name" -}} -{{- default .Chart.Name | trunc 63 | trimSuffix "-" }} -{{- end }} - -{{/* -Create a default fully qualified app name. -We truncate at 40 chars because some Kubernetes name fields are limited to 63 (by the DNS naming spec). -If release name contains chart name it will be used as a full name. -*/}} -{{- define "llmEngine.fullname" -}} -{{- if .Values.serviceIdentifier }} -{{- printf "%s-%s" .Chart.Name .Values.serviceIdentifier | trunc 40 | trimSuffix "-" }} -{{- else }} -{{- default .Chart.Name | trunc 40 | trimSuffix "-" }} -{{- end }} -{{- end }} - -{{- define "llmEngine.buildername" -}} -"{{ include "llmEngine.fullname" . }}-endpoint-builder" -{{- end }} - -{{- define "llmEngine.cachername" -}} -"{{ include "llmEngine.fullname" . }}-cacher" -{{- end }} - -{{/* -Create chart name and version as used by the chart label. -*/}} -{{- define "llmEngine.chart" -}} -{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} -{{- end }} - -{{/* -Common labels -*/}} -{{- define "llmEngine.labels" -}} -team: infra -product: llm-engine -helm.sh/chart: {{ include "llmEngine.chart" . }} -app.kubernetes.io/managed-by: {{ .Release.Service }} -app.kubernetes.io/version: {{ .Values.tag }} -tags.datadoghq.com/version: {{ .Values.tag }} -tags.datadoghq.com/env: {{ .Values.context }} -{{- end }} - -{{- define "llmEngine.selectorLabels.builder" -}} -app: {{ include "llmEngine.buildername" . }} -{{- end }} - -{{- define "llmEngine.selectorLabels.cacher" -}} -app: {{ include "llmEngine.cachername" . }} -{{- end }} - -{{- define "llmEngine.selectorLabels.gateway" -}} -app: {{ include "llmEngine.fullname" . -}} -{{- end }} - -{{- define "llmEngine.baseTemplateLabels" -}} -user_id: ${OWNER} -team: ${TEAM} -product: ${PRODUCT} -created_by: ${CREATED_BY} -owner: ${OWNER} -env: {{- .Values.context | printf " %s" }} -managed-by: {{- include "llmEngine.fullname" . | printf " %s\n" -}} -use_scale_llm_engine_endpoint_network_policy: "true" -tags.datadoghq.com/env: {{- .Values.context | printf " %s" }} -tags.datadoghq.com/version: {{- .Values.tag | printf " %s" }} -{{- end }} - -{{- define "llmEngine.serviceTemplateLabels" -}} -{{- include "llmEngine.baseTemplateLabels" . | printf "%s\n" -}} -tags.datadoghq.com/service: ${ENDPOINT_NAME} -endpoint_id: ${ENDPOINT_ID} -endpoint_name: ${ENDPOINT_NAME} -{{- end }} - -{{- define "llmEngine.jobTemplateLabels" -}} -{{- include "llmEngine.baseTemplateLabels" . | printf "%s\n" -}} -llm_engine_job_id: ${JOB_ID} -tags.datadoghq.com/service: ${JOB_ID} -{{- end }} - -{{- define "llmEngine.serviceTemplateAsyncAnnotations" -}} -celery.scaleml.autoscaler/queue: ${QUEUE} -celery.scaleml.autoscaler/broker: ${BROKER_NAME} -celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" -celery.scaleml.autoscaler/perWorker: "${PER_WORKER}" -celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" -celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" -{{- end }} - -{{- define "llmEngine.serviceTemplateAffinity" -}} -podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname -{{- end }} - -{{- define "llmEngine.baseServiceTemplateEnv" -}} -env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: {{ .Values.context }} - - name: DD_VERSION - value: {{ .Values.tag }} - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" - - name: ML_INFRA_SERVICES_CONFIG_PATH - {{- if .Values.config.file }} - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/{{ .Values.config.file.infra }}" - {{- else }} - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - {{- end }} -{{- end }} - -{{- define "llmEngine.syncServiceTemplateEnv" -}} -{{- include "llmEngine.baseServiceTemplateEnv" . }} - - name: PORT - value: "${ARTIFACT_LIKE_CONTAINER_PORT}" -{{- end }} - -{{- define "llmEngine.asyncServiceTemplateEnv" -}} -{{- include "llmEngine.baseServiceTemplateEnv" . }} - - name: CELERY_S3_BUCKET - value: "${CELERY_S3_BUCKET}" - - name: BROKER_TYPE - value: "${BROKER_TYPE}" - - name: SQS_PROFILE - value: "${SQS_PROFILE}" - - name: SQS_QUEUE_NAME - value: "${QUEUE}" - - name: SQS_QUEUE_URL - value: "${SQS_QUEUE_URL}" -{{- end }} - -{{- define "llmEngine.baseForwarderTemplateEnv" -}} -env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: {{ .Values.context }} - - name: DD_VERSION - value: {{ .Values.tag }} - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: BASE_PATH - value: "/workspace" - - name: ML_INFRA_SERVICES_CONFIG_PATH - {{- if .Values.config.file }} - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/{{ .Values.config.file.infra }}" - {{- else }} - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - {{- end }} -{{- end }} - -{{- define "llmEngine.syncForwarderTemplateEnv" -}} -{{- include "llmEngine.baseForwarderTemplateEnv" . }} -{{- if and .Values.forwarder .Values.forwarder.forceUseIPv4 }} - - name: HTTP_HOST - value: "0.0.0.0" -{{- end }} -{{- end }} - -{{- define "llmEngine.asyncForwarderTemplateEnv" -}} -{{- include "llmEngine.baseForwarderTemplateEnv" . }} - - name: CELERY_QUEUE - value: "${QUEUE}" - - name: CELERY_TASK_VISIBILITY - value: "VISIBILITY_24H" - - name: S3_BUCKET - value: "${CELERY_S3_BUCKET}" -{{- end }} - -{{- define "llmEngine.serviceEnv" }} -env: - - name: DATADOG_TRACE_ENABLED - value: "{{ .Values.datadog_trace_enabled }}" - - name: DD_ENV - value: {{ .Values.context }} - - name: DD_VERSION - value: {{ .Values.tag }} - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: GIT_TAG - value: {{ .Values.tag }} - - name: SERVICE_IDENTIFIER - {{- if .Values.serviceIdentifier }} - value: {{ .Values.serviceIdentifier }} - {{- end }} - {{- if .Values.aws }} - - name: AWS_PROFILE - value: {{ .Values.aws.profileName }} - - name: ECR_READ_AWS_PROFILE - value: {{ .Values.aws.profileName }} - {{- end }} - {{- with .Values.secrets }} - {{- if .kubernetesDatabaseSecretName }} - - name: ML_INFRA_DATABASE_URL - valueFrom: - secretKeyRef: - name: {{ .kubernetesDatabaseSecretName }} - key: database_url - {{- else if .awsDatabaseSecretName }} - - name: DB_SECRET_NAME - value: {{ .awsDatabaseSecretName }} - {{- end }} - {{- end }} - {{- if .Values.config.file }} - - name: DEPLOY_SERVICE_CONFIG_PATH - value: "/workspace/llm_engine/service_configs/{{ .Values.config.file.llm_engine }}" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/{{ .Values.config.file.infra }}" - {{- else }} - - name: DEPLOY_SERVICE_CONFIG_PATH - value: "/workspace/llm_engine/service_configs/service_config.yaml" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - {{- end }} - - name: CELERY_ELASTICACHE_ENABLED - value: "true" - - name: LLM_ENGINE_SERVICE_TEMPLATE_FOLDER - value: "/workspace/llm_engine/llm_engine/infra/gateways/resources/templates" -{{- end }} - -{{- define "llmEngine.gatewayEnv" }} -{{- include "llmEngine.serviceEnv" . }} - - name: DD_SERVICE - value: {{- printf " %s" (include "llmEngine.fullname" .) }} -{{- end }} - -{{- define "llmEngine.builderEnv" }} -{{- include "llmEngine.serviceEnv" . }} - - name: DD_SERVICE - value: {{- printf " %s" (include "llmEngine.buildername" .) }} -{{- end }} - -{{- define "llmEngine.cacherEnv" }} -{{- include "llmEngine.serviceEnv" . }} - - name: DD_SERVICE - value: {{- printf " %s" (include "llmEngine.cachername" .) }} -{{- end }} - -{{- define "llmEngine.volumes" }} -volumes: - - name: dshm - emptyDir: - medium: Memory - - name: service-template-config - configMap: - name: {{ include "llmEngine.fullname" . }}-service-template-config - {{- if .Values.aws }} - - name: config-volume - configMap: - name: {{ .Values.aws.configMap.name }} - {{- end }} - {{- if .Values.config.values }} - - name: llm-engine-service-config-volume - configMap: - name: {{ include "llmEngine.fullname" . }}-service-config - items: - - key: llm_engine_service_config - path: service_config.yaml - - name: infra-service-config-volume - configMap: - name: {{ include "llmEngine.fullname" . }}-service-config - items: - - key: infra_service_config - path: config.yaml - {{- end }} -{{- end }} - -{{- define "llmEngine.volumeMounts" }} -volumeMounts: - - name: dshm - mountPath: /dev/shm - - name: service-template-config - mountPath: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates - {{- if .Values.aws }} - - name: config-volume - mountPath: /home/user/.aws/config - subPath: config - {{- end }} - {{- if .Values.config.values }} - - name: llm-engine-service-config-volume - mountPath: /workspace/llm_engine/service_configs - - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs - {{- end }} -{{- end }} - -{{- define "llmEngine.forwarderVolumeMounts" }} -volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: user-config - mountPath: /workspace/user_config - subPath: raw_data - - name: endpoint-config - mountPath: /workspace/endpoint_config - subPath: raw_data - {{- if .Values.config.values }} - - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs - {{- end }} -{{- end }} - -{{- define "llmEngine.serviceAccountNamespaces" }} -namespaces: - - {{ .Release.Namespace }} -{{- range .Values.serviceAccount.namespaces }} - - {{ . }} -{{- end }} -{{- end }} diff --git a/charts/llm-engine/templates/aws_config_map.yaml b/charts/llm-engine/templates/aws_config_map.yaml deleted file mode 100644 index 48e2c30a..00000000 --- a/charts/llm-engine/templates/aws_config_map.yaml +++ /dev/null @@ -1,18 +0,0 @@ -{{- if .Values.aws }} -{{- if eq .Values.aws.configMap.create true }} -apiVersion: v1 -kind: ConfigMap -metadata: - name: {{ .Values.aws.configMap.name }} - labels: - {{- include "llmEngine.labels" . | nindent 4 }} - annotations: - "helm.sh/hook": pre-install,pre-upgrade - "helm.sh/hook-weight": "-2" -data: - config: |- - [profile {{ .Values.aws.profileName }}] - role_arn = {{ index .Values.serviceAccount.annotations "eks.amazonaws.com/role-arn" }} - web_identity_token_file = /var/run/secrets/eks.amazonaws.com/serviceaccount/token -{{- end }} -{{- end }} diff --git a/charts/llm-engine/templates/balloon_a100_deployment.yaml b/charts/llm-engine/templates/balloon_a100_deployment.yaml deleted file mode 100644 index 0559c2f6..00000000 --- a/charts/llm-engine/templates/balloon_a100_deployment.yaml +++ /dev/null @@ -1,48 +0,0 @@ -{{- if not .Values.serviceIdentifier }} -apiVersion: apps/v1 -kind: Deployment -metadata: - name: llm-engine-balloon-a100 - labels: - team: infra - product: common-warm-nodes -spec: - replicas: {{ .Values.replicaCount.balloonA100 }} - selector: - matchLabels: - app: llm-engine-balloon-a100 - version: v1 - template: - metadata: - labels: - app: llm-engine-balloon-a100 - product: common-warm-nodes - team: infra - env: {{ .Values.context }} - version: v1 - annotations: - sidecar.istio.io/inject: "false" - spec: - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-ampere-a100 - node-lifecycle: normal - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - containers: - - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent - name: main - resources: - limits: - memory: 28Gi - nvidia.com/gpu: 1 - cpu: 4 - command: - - /bin/bash - - -c - - "while true; do sleep 30; done" - terminationGracePeriodSeconds: 0 - priorityClassName: llm-engine-low-priority -{{- end }} diff --git a/charts/llm-engine/templates/balloon_t4_deployment.yaml b/charts/llm-engine/templates/balloon_t4_deployment.yaml deleted file mode 100644 index 8e871d06..00000000 --- a/charts/llm-engine/templates/balloon_t4_deployment.yaml +++ /dev/null @@ -1,48 +0,0 @@ -{{- if not .Values.serviceIdentifier }} -apiVersion: apps/v1 -kind: Deployment -metadata: - name: llm-engine-balloon-t4 - labels: - team: infra - product: common-warm-nodes -spec: - replicas: {{ .Values.replicaCount.balloonT4 }} - selector: - matchLabels: - app: llm-engine-balloon-t4 - version: v1 - template: - metadata: - labels: - app: llm-engine-balloon-t4 - product: common-warm-nodes - team: infra - env: {{ .Values.context }} - version: v1 - annotations: - sidecar.istio.io/inject: "false" - spec: - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-tesla-t4 - node-lifecycle: normal - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - containers: - - image: public.ecr.aws/ubuntu/ubuntu:latest - imagePullPolicy: IfNotPresent - name: main - resources: - limits: - memory: 28Gi - nvidia.com/gpu: 1 - cpu: 4 - command: - - /bin/bash - - -c - - "while true; do sleep 30; done" - terminationGracePeriodSeconds: 0 - priorityClassName: llm-engine-low-priority -{{- end }} diff --git a/charts/llm-engine/templates/gateway_service.yaml b/charts/llm-engine/templates/gateway_service.yaml deleted file mode 100644 index 9a3497c1..00000000 --- a/charts/llm-engine/templates/gateway_service.yaml +++ /dev/null @@ -1,15 +0,0 @@ -apiVersion: v1 -kind: Service -metadata: - name: {{ include "llmEngine.fullname" . }} - labels: - {{- include "llmEngine.labels" . | nindent 4 }} -spec: - type: {{ .Values.service.type }} - ports: - - port: {{ .Values.service.port }} - targetPort: http - protocol: TCP - name: http - selector: - {{- include "llmEngine.selectorLabels.gateway" . | nindent 4 }} diff --git a/charts/llm-engine/templates/launch_default_priority_class.yaml b/charts/llm-engine/templates/launch_default_priority_class.yaml deleted file mode 100644 index 1217c7c1..00000000 --- a/charts/llm-engine/templates/launch_default_priority_class.yaml +++ /dev/null @@ -1,11 +0,0 @@ -{{- if not .Values.serviceIdentifier }} -apiVersion: scheduling.k8s.io/v1 -kind: PriorityClass -metadata: - name: "{{ include "llmEngine.fullname" . }}-default-priority" -value: 1 -# This ensures that the default llm-engine pods will never preempt any pods, which means -# they cannot take advantage of the dummy nodes. -preemptionPolicy: Never -description: "Default Priority Class for LLMEngine" -{{- end }} diff --git a/charts/llm-engine/templates/service_config_map.yaml b/charts/llm-engine/templates/service_config_map.yaml deleted file mode 100644 index 003447dd..00000000 --- a/charts/llm-engine/templates/service_config_map.yaml +++ /dev/null @@ -1,25 +0,0 @@ -{{- if .Values.config.values }} -apiVersion: v1 -kind: ConfigMap -metadata: - name: {{ include "llmEngine.fullname" . }}-service-config - labels: - {{- include "llmEngine.labels" . | nindent 4 }} - annotations: - "helm.sh/hook": pre-install,pre-upgrade - "helm.sh/hook-weight": "-2" -data: - llm_engine_service_config: |- - {{- with .Values.config.values.llm_engine }} - {{- range $key, $value := . }} - {{ $key }}: {{ $value | quote }} - {{- end }} - {{- end }} - infra_service_config: |- - env: {{ .Values.context | quote }} - {{- with .Values.config.values.infra }} - {{- range $key, $value := . }} - {{ $key }}: {{ $value | quote }} - {{- end }} - {{- end }} -{{- end }} diff --git a/charts/llm-engine/templates/service_template_config_map.yaml b/charts/llm-engine/templates/service_template_config_map.yaml deleted file mode 100644 index 87b992cf..00000000 --- a/charts/llm-engine/templates/service_template_config_map.yaml +++ /dev/null @@ -1,744 +0,0 @@ -{{- $llm_engine_name := include "llmEngine.fullname" . }} -{{- $config_values := .Values.config.values }} -{{- $forwarder_repository := .Values.image.forwarderRepository -}} -{{- $triton_repository := .Values.triton.image.repository -}} -{{- $triton_tag := .Values.triton.image.tag -}} -{{- $env := .Values.context -}} -{{- $service_template_labels := include "llmEngine.serviceTemplateLabels" . }} -{{- $job_template_labels := include "llmEngine.jobTemplateLabels" . }} -{{- $service_env := include "llmEngine.serviceEnv" . }} -{{- $async_service_template_env := include "llmEngine.asyncServiceTemplateEnv" . }} -{{- $sync_service_template_env := include "llmEngine.syncServiceTemplateEnv" . }} -{{- $async_forwarder_template_env := include "llmEngine.asyncForwarderTemplateEnv" . }} -{{- $sync_forwarder_template_env := include "llmEngine.syncForwarderTemplateEnv" . }} -{{- $forwarder_volume_mounts := include "llmEngine.forwarderVolumeMounts" . }} -{{- $gateway_repository := .Values.image.gatewayRepository -}} -{{- $tag := .Values.tag -}} -{{- $aws_config_map_name := .Values.aws.configMap.name }} -{{- $security_context := .Values.serviceTemplate.securityContext }} -{{- $mount_infra_config := .Values.serviceTemplate.mountInfraConfig }} -{{- $service_template_service_account_name := .Values.serviceTemplate.serviceAccountName }} -{{- $service_template_aws_config_map_name := .Values.serviceTemplate.awsConfigMapName }} -{{- $celery_broker_type := .Values.celeryBrokerType }} - -{{- if .Values.message }} -{{- .Values.message }} -{{- end }} -apiVersion: v1 -kind: ConfigMap -metadata: - name: {{ $llm_engine_name }}-service-template-config - labels: - {{- include "llmEngine.labels" . | nindent 4 }} - annotations: - "helm.sh/hook": pre-install,pre-upgrade - "helm.sh/hook-weight": "-2" -data: - {{- range $device := tuple "cpu" "gpu" }} - {{- range $mode := tuple "async" "sync" "streaming"}} - {{- range $flavor := tuple "triton-enhanced-runnable-image" "runnable-image" "artifact" }} - {{- if or (ne $mode "streaming") (eq $flavor "runnable-image") }} - deployment-{{ $flavor }}-{{ $mode }}-{{ $device }}.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - {{- $service_template_labels | nindent 8 }} - {{- if eq $mode "async" }} - annotations: - {{- include "llmEngine.serviceTemplateAsyncAnnotations" . | nindent 8 }} - {{- end }} - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - {{- $service_template_labels | nindent 12 }} - {{- if eq $mode "async" }} - sidecar.istio.io/inject: "false" # TODO: switch to scuttle - {{- end }} - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - {{- include "llmEngine.serviceTemplateAffinity" . | nindent 12 }} - terminationGracePeriodSeconds: 600 - {{- if $service_template_service_account_name }} - serviceAccount: {{ $service_template_service_account_name }} - {{- else }} - serviceAccount: {{ $llm_engine_name }} - {{- end }} - nodeSelector: - node-lifecycle: normal - {{- if eq $device "gpu" }} - k8s.amazonaws.com/accelerator: ${GPU_TYPE} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - {{- end }} - priorityClassName: ${PRIORITY} - containers: - {{- if eq $flavor "artifact" }} - - image: ${IMAGE} - imagePullPolicy: IfNotPresent - name: main - {{- with $security_context }} - securityContext: - {{- toYaml . | nindent 16 }} - {{- end }} - {{- if eq $mode "async" }} - {{- $async_service_template_env | nindent 14 }} - {{- else if eq $mode "sync" }} - {{- $sync_service_template_env | nindent 14 }} - {{- end }} - readinessProbe: - {{- if eq $mode "async" }} - exec: - command: - - cat - - /tmp/readyz - {{- else if eq $mode "sync" }} - httpGet: - path: /readyz - port: ${ARTIFACT_LIKE_CONTAINER_PORT} - {{- end }} - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - {{- if eq $mode "async" }} - # Not including --pool=solo means there's a worker process and a separate supervisor process - # meaning if the worker crashes (because of OOM or something) the supervisor process can mark the task as - # failed, which should get rid of infinite task retries - args: - - celery - - --app=llm_engine.inference.async_inference - - worker - - --loglevel=INFO - - --concurrency=1 - - --queues=${QUEUE} - - -O - - fair - {{- else if eq $mode "sync" }} - args: - - python - - -m - - llm_engine.inference.sync_inference.start_fastapi_server - {{- end }} - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - {{- if eq $device "gpu" }} - nvidia.com/gpu: ${GPUS} - {{- end }} - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: ${BASE_PATH}/user_config - subPath: raw_data - - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config - subPath: raw_data - {{- if $config_values }} - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs - {{- end }} - {{- else if contains "runnable-image" $flavor }} - {{- if eq $mode "sync" }} - - name: http-forwarder - image: {{ $forwarder_repository }}:${FORWARDER_IMAGE_TAG} - imagePullPolicy: IfNotPresent - command: - - /usr/bin/dumb-init - - -- - - ddtrace-run - - run-service - - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads - - --port - - "${FORWARDER_PORT}" - - --concurrency - - "${PER_WORKER}" - - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" - - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - {{- $sync_forwarder_template_env | nindent 14 }} - readinessProbe: - httpGet: - path: /readyz - port: ${FORWARDER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} - periodSeconds: 5 - resources: - requests: - cpu: 0.1 - memory: "100M" - ephemeral-storage: "100M" - limits: - cpu: ${FORWARDER_CPUS_LIMIT} - memory: ${FORWARDER_MEMORY_LIMIT} - ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} - {{ $forwarder_volume_mounts | nindent 14 }} - ports: - - containerPort: ${FORWARDER_PORT} - name: http - {{- else if eq $mode "streaming" }} - - name: http-forwarder - image: {{ $forwarder_repository }}:{{ $tag }} - imagePullPolicy: IfNotPresent - command: - - /usr/bin/dumb-init - - -- - - ddtrace-run - - python - - -m - - llm_engine.inference.forwarding.http_forwarder - - --config - - /workspace/llm_engine/llm_engine/inference/configs/service--http_forwarder.yaml - - --port - - "${FORWARDER_PORT}" - - --num-workers - - "${PER_WORKER}" - - --set - - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - - --set - - "forwarder.stream.predict_route=${STREAMING_PREDICT_ROUTE}" - - --set - - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --set - - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" - {{- $sync_forwarder_template_env | nindent 14 }} - readinessProbe: - httpGet: - path: /readyz - port: ${FORWARDER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} - periodSeconds: 5 - resources: - requests: - cpu: 0.1 - memory: "100M" - ephemeral-storage: "100M" - limits: - cpu: ${FORWARDER_CPUS_LIMIT} - memory: ${FORWARDER_MEMORY_LIMIT} - ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} - {{ $forwarder_volume_mounts | nindent 14 }} - ports: - - containerPort: ${FORWARDER_PORT} - name: http - {{- else if eq $mode "async" }} - - name: celery-forwarder - image: {{ $forwarder_repository }}:${FORWARDER_IMAGE_TAG} - imagePullPolicy: IfNotPresent - command: - - /usr/bin/dumb-init - - -- - - ddtrace-run - - run-service - - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --queue - - "${QUEUE}" - - --task-visibility - - "VISIBILITY_24H" - - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" - - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - {{- if eq $celery_broker_type "sqs" }} - - --sqs-url - - "${SQS_QUEUE_URL}" - {{- end }} - - --concurrency - - "${PER_WORKER}" - {{- $async_forwarder_template_env | nindent 14 }} - resources: - requests: - cpu: 0.1 - memory: "100M" - ephemeral-storage: "100M" - limits: - cpu: ${FORWARDER_CPUS_LIMIT} - memory: ${FORWARDER_MEMORY_LIMIT} - ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} - {{ $forwarder_volume_mounts | nindent 14 }} - {{- end }} - {{- if eq $flavor "triton-enhanced-runnable-image" }} - - name: tritonserver - image: {{ $triton_repository }}:${TRITON_COMMIT_TAG}-triton - imagePullPolicy: IfNotPresent - command: - - /usr/bin/dumb-init - - -- - - bash - - -c - - "$TRITON_COMMAND" - env: - - name: AWS_PROFILE - value: "${AWS_ROLE}" - ports: - - containerPort: 8000 - name: http - - containerPort: 8001 - name: grpc - - containerPort: 8002 - name: metrics - readinessProbe: - httpGet: - # Need to have Triton support --http-address IPv6 :( - # https://github:com/triton-inference-server/server/issues/5305: - # path: /v2/health/ready - # port: 8000 - path: /readyz - port: 3000 - initialDelaySeconds: $TRITON_READINESS_INITIAL_DELAY - periodSeconds: 10 - resources: - requests: - cpu: ${TRITON_CPUS} - ${TRITON_MEMORY_DICT} - ${TRITON_STORAGE_DICT} - limits: - cpu: ${TRITON_CPUS} - ${TRITON_MEMORY_DICT} - ${TRITON_STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - mountPath: /dev/shm - name: dshm - {{- end }} - - name: main - {{- with $security_context }} - securityContext: - {{- toYaml . | nindent 16 }} - {{- end }} - image: ${IMAGE} - imagePullPolicy: IfNotPresent - command: ${COMMAND} - env: ${MAIN_ENV} - readinessProbe: - httpGet: - path: ${HEALTHCHECK_ROUTE} - port: ${USER_CONTAINER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} - periodSeconds: 5 - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - {{- if eq $device "gpu" }} - nvidia.com/gpu: ${GPUS} - {{- end }} - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - mountPath: /dev/shm - name: dshm - {{- if $mount_infra_config }} - - name: infra-service-config-volume - mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - {{- end }} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: /app/user_config - subPath: raw_data - - name: endpoint-config - mountPath: /app/endpoint_config - subPath: raw_data - ports: - - containerPort: ${USER_CONTAINER_PORT} - name: http - {{- end }} - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - {{- if $service_template_aws_config_map_name }} - name: {{ $service_template_aws_config_map_name }} - {{- else }} - name: {{ $aws_config_map_name }} - {{- end }} - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - {{- if $config_values }} - - name: infra-service-config-volume - configMap: - name: {{ $llm_engine_name }}-service-config - items: - - key: infra_service_config - path: config.yaml - {{- end }} - {{- end }} - {{- end }} - {{- end }} - {{- end }} - user-config.yaml: |- - apiVersion: v1 - kind: ConfigMap - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - {{- $service_template_labels | nindent 8 }} - data: - raw_data: ${CONFIG_DATA_SERIALIZED} - endpoint-config.yaml: |- - apiVersion: v1 - kind: ConfigMap - metadata: - name: ${RESOURCE_NAME}-endpoint-config - namespace: ${NAMESPACE} - labels: - {{- $service_template_labels | nindent 8 }} - data: - raw_data: ${ENDPOINT_CONFIG_SERIALIZED} - horizontal-pod-autoscaler.yaml: |- - apiVersion: ${API_VERSION} - kind: HorizontalPodAutoscaler - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - {{- $service_template_labels | nindent 8 }} - spec: - minReplicas: ${MIN_WORKERS} - maxReplicas: ${MAX_WORKERS} - scaleTargetRef: - apiVersion: apps/v1 - kind: Deployment - name: ${RESOURCE_NAME} - metrics: - - type: Pods - pods: - metric: - name: request-concurrency-average - target: - type: Value - averageValue: ${CONCURRENCY} - service.yaml: |- - apiVersion: v1 - kind: Service - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - {{- $service_template_labels | nindent 8 }} - spec: - type: ${SERVICE_TYPE} - selector: - app: ${RESOURCE_NAME} - ports: - - port: 80 - targetPort: ${SERVICE_TARGET_PORT} - protocol: TCP - name: http - ${NODE_PORT_DICT} - vertical-pod-autoscaler.yaml: |- - apiVersion: "autoscaling.k8s.io/v1" - kind: VerticalPodAutoscaler - metadata: - name: ${RESOURCE_NAME} - labels: - {{- $service_template_labels | nindent 8 }} - spec: - targetRef: - apiVersion: "apps/v1" - kind: Deployment - name: ${RESOURCE_NAME} - updatePolicy: - updateMode: "Auto" - resourcePolicy: - containerPolicies: - - containerName: istio-proxy - mode: "Off" - - containerName: main - minAllowed: - cpu: 100m - memory: 128Mi - maxAllowed: - cpu: ${CPUS} - memory: ${MEMORY} - controlledResources: ["cpu", "memory"] - batch-job-orchestration-job.yaml: |- - apiVersion: batch/v1 - kind: Job - metadata: - name: ${RESOURCE_NAME} - labels: - {{- $job_template_labels | nindent 8 }} - spec: - backoffLimit: 0 - activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} - ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} - template: - metadata: - labels: - {{- $job_template_labels | nindent 12 }} - sidecar.istio.io/inject: "false" - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "llm_engine_job_id:${JOB_ID}"]}]' - cluster-autoscaler.kubernetes.io/safe-to-evict: "false" - spec: - restartPolicy: Never - nodeSelector: - node-lifecycle: normal - serviceAccountName: {{ $llm_engine_name }} - volumes: - - name: config-volume - configMap: - name: {{ $aws_config_map_name }} - containers: - - name: main - image: {{ $gateway_repository }}:{{ $tag }} - env: - - name: DD_SERVICE - value: ${RESOURCE_NAME} - {{- $env_vars := include "llmEngine.serviceEnv" . | fromYaml }} - {{- range $env_var := index $env_vars "env" }} - {{- $env_var_name := index $env_var "name" }} - {{- if ne $env_var_name "DD_SERVICE" }} - {{- tuple $env_var | toYaml | nindent 16 }} - {{- end }} - {{- end }} - imagePullPolicy: Always - command: - - dumb-init - - -- - - ddtrace-run - args: - - python - - -m - - server.llm_engine_server.entrypoints.start_batch_job_orchestration - - --job-id - - ${JOB_ID} - - --owner - - ${OWNER} - - --input-path - - ${INPUT_LOCATION} - - --serialization-format - - ${SERIALIZATION_FORMAT} - - --timeout-seconds - - "${BATCH_JOB_TIMEOUT}" - resources: - # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. - requests: - cpu: 1 - memory: 8Gi - limits: - cpu: 4 - memory: 32Gi - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - {{- range $device := tuple "cpu" "gpu" }} - docker-image-batch-job-{{- $device }}.yaml: |- - apiVersion: batch/v1 - kind: Job - metadata: - name: ${RESOURCE_NAME} - labels: - {{- $job_template_labels | nindent 8 }} - spec: - backoffLimit: 0 - activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} - ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} - template: - metadata: - labels: - {{- $job_template_labels | nindent 12 }} - sidecar.istio.io/inject: "false" - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "llm_engine_job_id:${JOB_ID}"]}]' - spec: - restartPolicy: Never - nodeSelector: - node-lifecycle: normal - {{- if eq $device "gpu" }} - k8s.amazonaws.com/accelerator: ${GPU_TYPE} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - {{- end }} - {{- if $service_template_service_account_name }} - serviceAccountName: {{ $service_template_service_account_name }} - {{- else }} - serviceAccountName: {{ $llm_engine_name }} - {{- end }} - volumes: - - name: config-volume - configMap: - name: {{ $aws_config_map_name }} - - name: workdir - emptyDir: {} - - name: dshm - emptyDir: - medium: Memory - containers: - - name: main - image: ${IMAGE} - env: - - name: DD_SERVICE - value: ${RESOURCE_NAME} - {{- $env_vars := $service_env | fromYaml }} - {{- range $env_var := index $env_vars "env" }} - {{- $env_var_name := index $env_var "name" }} - {{- if ne $env_var_name "DD_SERVICE" }} - {{- tuple $env_var | toYaml | nindent 16 }} - {{- end }} - {{- end }} - imagePullPolicy: Always - command: ${COMMAND} - resources: - # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - {{- if eq $device "gpu" }} - nvidia.com/gpu: ${GPUS} - {{- end }} - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: workdir - mountPath: ${MOUNT_PATH} - - mountPath: /dev/shm - name: dshm - initContainers: - - name: input-downloader - image: {{ $gateway_repository }}:{{ $tag }} - command: - - python - - -m - - server.llm_engine_server.entrypoints.start_docker_image_batch_job_init_container - - ${INPUT_LOCATION} - - --remote-file - - ${S3_FILE} - - --local-file - - ${LOCAL_FILE_NAME} - - --file-contents-b64encoded - - ${FILE_CONTENTS_B64ENCODED} - resources: - requests: - cpu: 1 - memory: 1Gi - limits: - cpu: 1 - memory: 1Gi - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: workdir - mountPath: ${MOUNT_PATH} - {{- end }} - {{- range $device := .Values.imageCache.devices }} - {{- $device_node_selector := index $device "nodeSelector" }} - {{- $device_tolerations := index $device "tolerations" }} - image-cache-{{- index $device "name" }}.yaml: |- - apiVersion: apps/v1 - kind: DaemonSet - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/service: ${RESOURCE_NAME} - spec: - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - updateStrategy: - type: RollingUpdate - template: - metadata: - labels: - app: ${RESOURCE_NAME} - team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/service: ${RESOURCE_NAME} - version: v1 - sidecar.istio.io/inject: "false" - spec: - {{- if $device_node_selector }} - {{- with $device_node_selector }} - nodeSelector: - {{- toYaml . | nindent 12 }} - {{- end }} - {{- end }} - {{- if $device_tolerations }} - {{- with $device_tolerations }} - tolerations: - {{- toYaml . | nindent 12 }} - {{- end }} - {{- end }} - containers: - - image: public.ecr.aws/docker/library/busybox:latest - imagePullPolicy: IfNotPresent - name: busybox - command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] - terminationGracePeriodSeconds: 0 - {{- end }} diff --git a/charts/llm-engine/values_circleci.yaml b/charts/llm-engine/values_circleci.yaml deleted file mode 100644 index b57e0ec1..00000000 --- a/charts/llm-engine/values_circleci.yaml +++ /dev/null @@ -1,187 +0,0 @@ -# This is a YAML-formatted file. - -replicaCount: - gateway: 1 - cacher: 1 - builder: 1 - balloonA10: 0 - balloonA100: 0 - balloonCpu: 0 - balloonT4: 0 - -# tag needs to be set dynamically every time. Usually it is set to the SHA1 hash of the git -# commit from which the image was built. -# tag: -context: circleci -image: - gatewayRepository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine - builderRepository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine - cacherRepository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine - forwarderRepository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine - pullPolicy: Always - -# serviceIdentifier: - -secrets: - awsDatabaseSecretName: prod/llm_engine.db - -service: - type: ClusterIP - port: 80 - -virtualservice: - enabled: true - annotations: { } - hostDomains: - - ml-internal.scale.com - gateways: - - default/internal-gateway - -destinationrule: - enabled: true - annotations: { } - -autoscaling: - horizontal: - enabled: false - minReplicas: 1 - maxReplicas: 10 - targetConcurrency: 30 - vertical: - enabled: false - minAllowed: - cpu: 100m - memory: 128Mi - maxAllowed: - cpu: 10 - memory: 8Gi - updateMode: Auto - prewarming: - enabled: false - -resources: - requests: - cpu: 2 - -nodeSelector: { } - -tolerations: [ ] - -affinity: { } - -config: - values: - infra: - k8s_cluster_name: minikube - dns_host_domain: localhost - default_region: us-west-2 - ml_account_id: "000000000000" - docker_repo_prefix: "000000000000.dkr.ecr.us-west-2.amazonaws.com" - redis_host: redis-message-broker-master.default - s3_bucket: "scale-ml-circleci" - profile_ml_worker: "default" - profile_ml_inference_worker: "default" - llm_engine: - # Endpoint config - # K8s namespace the endpoints will be created in - endpoint_namespace: scale-deploy - - # Asynchronous endpoints - sqs_profile: default - sqs_queue_policy_template: > - { - "Version": "2012-10-17", - "Id": "__default_policy_ID", - "Statement": [ - { - "Sid": "__owner_statement", - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:root" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/default" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/ml_llm_engine" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - } - ] - } - sqs_queue_tag_template: > - { - "Spellbook-Serve-Endpoint-Id": "${endpoint_id}", - "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", - "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" - } - - cache_redis_url: redis://redis-message-broker-master.default/15 - -# Service Account -serviceAccount: - annotations: - eks.amazonaws.com/role-arn: arn:aws:iam::000000000000:role/eks-default2 - -aws: - configMap: - name: default-config - create: false - profileName: default - -forwarder: - forceUseIPv4: true - -triton: - image: - repository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv - tag: e83eccbc8959f90ebbe4bda618b61ec6ee2d8394-triton - -serviceTemplate: - securityContext: - capabilities: - drop: - - all - mountInfraConfig: true - serviceAccountName: default - awsConfigMapName: default-config - -imageCache: - devices: - - name: cpu - nodeSelector: - cpu-only: "true" - - name: a10 - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-ampere-a10 - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - - name: a100 - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-ampere-a100 - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - - name: t4 - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-tesla-t4 - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - -celeryBrokerType: redis diff --git a/charts/llm-engine/values_sample.yaml b/charts/llm-engine/values_sample.yaml deleted file mode 100644 index 7b2cbbf0..00000000 --- a/charts/llm-engine/values_sample.yaml +++ /dev/null @@ -1,176 +0,0 @@ -# This is a YAML-formatted file. - -# tag [required] is the LLM Engine docker image tag -tag: 1defd4f9c5376149e27673e154731a0c7820fe5d -# context is a user-specified deployment tag. Can be used to -context: production -image: - # gatewayRepository [required] is the docker repository to pull the LLM Engine gateway image from - gatewayRepository: public.ecr.aws/b2z8n5q1/llm-engine - # builderRepository [required] is the docker repository to pull the LLM Engine endpoint builder image from - builderRepository: public.ecr.aws/b2z8n5q1/llm-engine - # cacherRepository [required] is the docker repository to pull the LLM Engine cacher image from - cacherRepository: public.ecr.aws/b2z8n5q1/llm-engine - # forwarderRepository [required] is the docker repository to pull the LLM Engine forwarder image from - forwarderRepository: public.ecr.aws/b2z8n5q1/llm-engine - # pullPolicy is the docker image pull policy - pullPolicy: Always - -secrets: - # kubernetesDatabaseSecretName [required] is the name of the secret that contains the database credentials - kubernetesDatabaseSecretName: llm-engine-postgres-credentials - -# serviceAccount [required] specifies the service account for LLM Engine server deployments (e.g gateway, cache, and builder deployments). -serviceAccount: - annotations: - # eks.amazonaws.com/role-arn [required] is the ARN of the IAM role that the service account will assume - eks.amazonaws.com/role-arn: arn:aws:iam::000000000000:role/k8s-main-llm-engine - "helm.sh/hook": pre-install,pre-upgrade - "helm.sh/hook-weight": "-2" - namespaces: [] - -# service specifies the service configuration for the main LLM Engine server. Users should setup their own ingress controller to expose the service. -service: - type: ClusterIP - port: 80 - -# replicaCount specifies the amount of replica pods for each deployment -replicaCount: - # gateway is the main LLM Engine server deployment - gateway: 2 - # cacher is the kubernetes state caching deployment - cacher: 1 - # builder is the endpoint builder deployment - builder: 1 - # balloonA10 is a low priority pod deployment for A10 GPU nodes - balloonA10: 0 - # balloonA100 is a low priority pod deployment for A100 GPU nodes - balloonA100: 0 - # balloonCpu is a low priority pod deployment for CPU nodes - balloonCpu: 0 - # balloonT4 is a low priority pod deployment for T4 GPU nodes - balloonT4: 0 - -# autoscaling is the autoscaling configuration for LLM Engine server deployments (e.g gateway, cache, and builder deployments) -autoscaling: - horizontal: - enabled: true - minReplicas: 2 - maxReplicas: 10 - targetConcurrency: 50 - vertical: - enabled: false - prewarming: - enabled: false - -# resources specify the k8s resources for LLM Engine server deployments (e.g gateway, cache, and builder deployments) -resources: - requests: - cpu: 2 -# nodeSelector specifies the node selector for LLM Engine server deployments (e.g gateway, cache, and builder deployments) -nodeSelector: { } -# tolerations specifies the tolerations for LLM Engine server deployments (e.g gateway, cache, and builder deployments) -tolerations: [ ] -# affinity specifies the affinity for LLM Engine server deployments (e.g gateway, cache, and builder deployments) -affinity: { } - -# aws specifies the AWS configurations (by configMap) for LLM Engine server deployments -aws: - configMap: - name: default-config - create: true - profileName: default - -# serviceTemplate specifies additional flags for model endpoints -serviceTemplate: - securityContext: - capabilities: - drop: - - all - mountInfraConfig: true - -# config specifes the `data` field of the service config map -config: - values: - infra: - # k8s_cluster_name [required] is the name of the k8s cluster - k8s_cluster_name: main_cluster - # dns_host_domain [required] is the domain name of the k8s cluster - dns_host_domain: domain.llm-engine.com - # default_region [required] is the default AWS region for various resources (e.g ECR) - default_region: us-east-1 - # aws_account_id [required] is the AWS account ID for various resources (e.g ECR) - ml_account_id: "000000000000" - # docker_repo_prefix [required] is the prefix for AWS ECR repositories - docker_repo_prefix: "000000000000.dkr.ecr.us-east-1.amazonaws.com" - # redis_host [required] is the hostname of the redis cluster you wish to connect - redis_host: llm-engine-prod-cache.use1.cache.amazonaws.com - # s3_bucket [required] is the S3 bucket you wish to connect - s3_bucket: "llm-engine" - llm_engine: - # endpoint_namespace [required] is K8s namespace the endpoints will be created in - endpoint_namespace: llm-engine - # cache_redis_url [required] is the full url for the redis cluster you wish to connect - cache_redis_url: redis://llm-engine-prod-cache.use1.cache.amazonaws.com:6379/15 - # s3_file_llm_fine_tuning_job_repository [required] is the S3 URI for the S3 bucket/key that you wish to save fine-tuned assests - s3_file_llm_fine_tuning_job_repository: "s3://llm-engine/llm-ft-job-repository" - # datadog_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster - datadog_trace_enabled: false - - # Asynchronous endpoints configs (coming soon) - sqs_profile: default - # sqs_queue_policy_template [required] is the IAM policy template for SQS queue for async endpoints. - sqs_queue_policy_template: > - { - "Version": "2012-10-17", - "Id": "__default_policy_ID", - "Statement": [ - { - "Sid": "__owner_statement", - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:root" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-east-1:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/k8s-main-llm-engine" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-east-1:000000000000:${queue_name}" - } - ] - } - - sqs_queue_tag_template: > - { - "Spellbook-Serve-Endpoint-Id": "${endpoint_id}", - "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", - "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" - } - -# Triton enhanced endpoints (coming soon) -triton: - image: - repository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv - tag: e83eccbc8959f90ebbe4bda618b61ec6ee2d8394-triton - -# imageCache specifies the image cache configuration for faster endpoint auto-scaling (coming soon) -imageCache: - devices: - - name: cpu - nodeSelector: - cpu-only: "true" - - name: a10 - nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-ampere-a10 - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - -# celeryBrokerType specifies the celery broker type for async endpoints (coming soon) -celeryBrokerType: sqs diff --git a/charts/llm-engine/.helmignore b/charts/model-engine/.helmignore similarity index 100% rename from charts/llm-engine/.helmignore rename to charts/model-engine/.helmignore diff --git a/charts/llm-engine/Chart.yaml b/charts/model-engine/Chart.yaml similarity index 97% rename from charts/llm-engine/Chart.yaml rename to charts/model-engine/Chart.yaml index 40300d18..e9ef0518 100644 --- a/charts/llm-engine/Chart.yaml +++ b/charts/model-engine/Chart.yaml @@ -1,5 +1,5 @@ apiVersion: v2 -name: llm-engine +name: model-engine description: A Helm chart for Kubernetes # A chart can be either an 'application' or a 'library' chart. @@ -15,7 +15,7 @@ type: application # This is the chart version. This version number should be incremented each time you make changes # to the chart and its templates, including the app version. # Versions are expected to follow Semantic Versioning (https://semver.org/) -version: 0.1.0 +version: 0.1.5 # This is the version number of the application being deployed. This version number should be # incremented each time you make changes to the application. Versions are not expected to diff --git a/charts/model-engine/README.md b/charts/model-engine/README.md new file mode 100644 index 00000000..19c826ce --- /dev/null +++ b/charts/model-engine/README.md @@ -0,0 +1,3 @@ +# Scale Launch Helm Chart + +This chart contains k8s templates for the gateway, endpoint builder, and k8s cacher. diff --git a/charts/model-engine/templates/_helpers.tpl b/charts/model-engine/templates/_helpers.tpl new file mode 100644 index 00000000..50383770 --- /dev/null +++ b/charts/model-engine/templates/_helpers.tpl @@ -0,0 +1,478 @@ +{{/* +Expand the name of the chart. +*/}} +{{- define "modelEngine.name" -}} +{{- default .Chart.Name | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{/* +Create a default fully qualified app name. +We truncate at 40 chars because some Kubernetes name fields are limited to 63 (by the DNS naming spec). +If release name contains chart name it will be used as a full name. +*/}} +{{- define "modelEngine.fullname" -}} +{{- if .Values.serviceIdentifier }} +{{- printf "%s-%s" .Chart.Name .Values.serviceIdentifier | trunc 40 | trimSuffix "-" }} +{{- else }} +{{- default .Chart.Name | trunc 40 | trimSuffix "-" }} +{{- end }} +{{- end }} + +{{- define "modelEngine.buildername" -}} +"{{ include "modelEngine.fullname" . }}-endpoint-builder" +{{- end }} + +{{- define "modelEngine.cachername" -}} +"{{ include "modelEngine.fullname" . }}-cacher" +{{- end }} + +{{- define "modelEngine.gatewayurl" -}} +{{ .Values.hostDomain.prefix }}{{ include "modelEngine.fullname" . }}.{{ .Release.Namespace }}:{{ .Values.service.port }} +{{- end }} + +{{- define "modelEngine.celeryautoscalername" -}} +{{- if .Values.serviceIdentifier }} +{{- printf "celery-autoscaler-%s-%s" .Values.celeryBrokerType .Values.serviceIdentifier }} +{{- else }} +{{- printf "celery-autoscaler-%s" .Values.celeryBrokerType }} +{{- end }} +{{- end }} + +{{/* +Create chart name and version as used by the chart label. +*/}} +{{- define "modelEngine.chart" -}} +{{- printf "%s-%s" .Chart.Name .Chart.Version | replace "+" "_" | trunc 63 | trimSuffix "-" }} +{{- end }} + +{{- define "modelEngine.baseLabels" -}} +team: infra +app.kubernetes.io/version: {{ .Values.tag }} +tags.datadoghq.com/version: {{ .Values.tag }} +tags.datadoghq.com/env: {{ .Values.context }} +env: {{ .Values.context }} +{{- if .Values.azure }} +azure.workload.identity/use: "true" +{{- end }} +{{- end }} + +{{/* +Common labels +*/}} +{{- define "modelEngine.labels" -}} +{{- include "modelEngine.baseLabels" . | printf "%s\n" -}} +product: model-engine +helm.sh/chart: {{ include "modelEngine.chart" . }} +app.kubernetes.io/managed-by: {{ .Release.Service }} +{{- end }} + +{{- define "modelEngine.selectorLabels.builder" -}} +app: {{ include "modelEngine.buildername" . }} +{{- end }} + +{{- define "modelEngine.selectorLabels.cacher" -}} +app: {{ include "modelEngine.cachername" . }} +{{- end }} + +{{- define "modelEngine.selectorLabels.gateway" -}} +app: {{ include "modelEngine.fullname" . -}} +{{- end }} + +{{- define "modelEngine.selectorLabels.celeryAutoscaler" -}} +app: {{ include "modelEngine.celeryautoscalername" . }} +product: common +tags.datadoghq.com/service: {{ include "modelEngine.celeryautoscalername" . -}} +{{- end }} + +{{- define "modelEngine.baseTemplateLabels" -}} +user_id: ${OWNER} +team: ${TEAM} +product: ${PRODUCT} +created_by: ${CREATED_BY} +owner: ${OWNER} +env: {{- .Values.context | printf " %s" }} +managed-by: {{- include "modelEngine.fullname" . | printf " %s\n" -}} +use_scale_launch_endpoint_network_policy: "true" +tags.datadoghq.com/env: {{- .Values.context | printf " %s" }} +tags.datadoghq.com/version: ${GIT_TAG} +{{- if .Values.azure }} +azure.workload.identity/use: "true" +{{- end }} +{{- end }} + +{{- define "modelEngine.serviceTemplateLabels" -}} +{{- include "modelEngine.baseTemplateLabels" . | printf "%s\n" -}} +tags.datadoghq.com/service: ${ENDPOINT_NAME} +endpoint_id: ${ENDPOINT_ID} +endpoint_name: ${ENDPOINT_NAME} +{{- end }} + +{{- define "modelEngine.jobTemplateLabels" -}} +{{- include "modelEngine.baseTemplateLabels" . | printf "%s\n" -}} +launch_job_id: ${JOB_ID} +tags.datadoghq.com/request_id: ${REQUEST_ID} +tags.datadoghq.com/service: ${JOB_ID} +tags.datadoghq.com/user_id: ${OWNER} +tags.datadoghq.com/team: ${TEAM} +{{- end }} + +{{- define "modelEngine.serviceTemplateAsyncAnnotations" -}} +celery.scaleml.autoscaler/queue: ${QUEUE} +celery.scaleml.autoscaler/broker: ${BROKER_NAME} +celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" +celery.scaleml.autoscaler/perWorker: "${PER_WORKER}" +celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" +celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" +{{- end }} + +{{- define "modelEngine.serviceTemplateAffinity" -}} +podAffinity: + preferredDuringSchedulingIgnoredDuringExecution: + - weight: 1 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: app + operator: In + values: + - ${RESOURCE_NAME} + topologyKey: kubernetes.io/hostname + - weight: 100 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: ${IMAGE_HASH} + operator: In + values: + - "True" + topologyKey: kubernetes.io/hostname +{{- end }} + +{{- define "modelEngine.baseServiceTemplateEnv" -}} +env: + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" + - name: DD_SERVICE + value: "${ENDPOINT_NAME}" + - name: DD_ENV + value: {{ .Values.context }} + - name: DD_VERSION + value: "${GIT_TAG}" + - name: DD_AGENT_HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + - name: OMP_NUM_THREADS + value: "1" + - name: BASE_PATH + value: "${BASE_PATH}" + - name: BUNDLE_URL + value: "${BUNDLE_URL}" + - name: LOAD_PREDICT_FN_MODULE_PATH + value: "${LOAD_PREDICT_FN_MODULE_PATH}" + - name: LOAD_MODEL_FN_MODULE_PATH + value: "${LOAD_MODEL_FN_MODULE_PATH}" + {{- if .Values.aws }} + - name: AWS_PROFILE + value: "${AWS_ROLE}" + {{- end }} + - name: RESULTS_S3_BUCKET + value: "${RESULTS_S3_BUCKET}" + - name: CHILD_FN_INFO + value: "${CHILD_FN_INFO}" + - name: PREWARM + value: "${PREWARM}" + - name: ML_INFRA_SERVICES_CONFIG_PATH + {{- if .Values.config.file }} + value: {{ .Values.config.file.infra | quote }} + {{- else }} + value: "${BASE_PATH}/model-engine/model_engine_server/core/configs/config.yaml" + {{- end }} +{{- end }} + +{{- define "modelEngine.syncServiceTemplateEnv" -}} +{{- include "modelEngine.baseServiceTemplateEnv" . }} + - name: PORT + value: "${ARTIFACT_LIKE_CONTAINER_PORT}" +{{- end }} + +{{- define "modelEngine.asyncServiceTemplateEnv" -}} +{{- include "modelEngine.baseServiceTemplateEnv" . }} + - name: CELERY_S3_BUCKET + value: "${CELERY_S3_BUCKET}" + - name: BROKER_TYPE + value: "${BROKER_TYPE}" + - name: SQS_PROFILE + value: "${SQS_PROFILE}" + - name: SQS_QUEUE_NAME + value: "${QUEUE}" + - name: SQS_QUEUE_URL + value: "${SQS_QUEUE_URL}" +{{- end }} + +{{- define "modelEngine.baseForwarderTemplateEnv" -}} +env: + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" + - name: DD_SERVICE + value: "${ENDPOINT_NAME}" + - name: DD_ENV + value: {{ .Values.context }} + - name: DD_VERSION + value: "${GIT_TAG}" + - name: DD_AGENT_HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + {{- if .Values.aws }} + - name: AWS_PROFILE + value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config + {{- end }} + - name: RESULTS_S3_BUCKET + value: "${RESULTS_S3_BUCKET}" + - name: BASE_PATH + value: "/workspace" + - name: ML_INFRA_SERVICES_CONFIG_PATH + {{- if .Values.config.file }} + value: {{ .Values.config.file.infra | quote }} + {{- else }} + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" + {{- end }} + {{- if .Values.azure}} + - name: AZURE_IDENTITY_NAME + value: {{ .Values.azure.identity_name }} + - name: AZURE_CLIENT_ID + value: {{ .Values.azure.client_id }} + - name: AZURE_OBJECT_ID + value: {{ .Values.azure.object_id }} + - name: ABS_ACCOUNT_NAME + value: {{ .Values.azure.abs_account_name }} + - name: ABS_CONTAINER_NAME + value: {{ .Values.azure.abs_container_name }} + {{- end }} +{{- end }} + +{{- define "modelEngine.syncForwarderTemplateEnv" -}} +{{- include "modelEngine.baseForwarderTemplateEnv" . }} +{{- if and .Values.forwarder .Values.forwarder.forceUseIPv4 }} + - name: HTTP_HOST + value: "0.0.0.0" +{{- end }} +{{- end }} + +{{- define "modelEngine.asyncForwarderTemplateEnv" -}} +{{- include "modelEngine.baseForwarderTemplateEnv" . }} + - name: CELERY_QUEUE + value: "${QUEUE}" + - name: CELERY_TASK_VISIBILITY + value: "VISIBILITY_24H" + - name: S3_BUCKET + value: "${CELERY_S3_BUCKET}" + {{- if .Values.azure}} + - name: ABS_ACCOUNT_NAME + value: {{ .Values.azure.abs_account_name }} + - name: ABS_CONTAINER_NAME + value: {{ .Values.azure.abs_container_name }} + - name: SERVICEBUS_NAMESPACE + value: {{ .Values.azure.servicebus_namespace }} + {{- end }} +{{- end }} + +{{- define "modelEngine.serviceEnvBase" }} +env: + - name: DD_TRACE_ENABLED + value: "{{ .Values.dd_trace_enabled }}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" + - name: DD_ENV + value: {{ .Values.context }} + - name: DD_AGENT_HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + - name: SERVICE_IDENTIFIER + {{- if .Values.serviceIdentifier }} + value: {{ .Values.serviceIdentifier }} + {{- end }} + - name: GATEWAY_URL + value: {{ include "modelEngine.gatewayurl" . }} + {{- if .Values.aws }} + - name: AWS_PROFILE + value: {{ .Values.aws.profileName }} + - name: AWS_CONFIG_FILE + value: /opt/.aws/config + - name: ECR_READ_AWS_PROFILE + value: {{ .Values.aws.profileName }} + - name: DB_SECRET_AWS_PROFILE + value: {{ .Values.aws.profileName }} + - name: S3_WRITE_AWS_PROFILE + value: {{ .Values.aws.s3WriteProfileName }} + {{- end }} + {{- with .Values.secrets }} + {{- if .kubernetesDatabaseSecretName }} + - name: ML_INFRA_DATABASE_URL + valueFrom: + secretKeyRef: + name: {{ .kubernetesDatabaseSecretName }} + key: database_url + {{- else if .cloudDatabaseSecretName }} + - name: DB_SECRET_NAME + value: {{ .cloudDatabaseSecretName }} + {{- end }} + {{- end }} + {{- if .Values.config.file }} + - name: DEPLOY_SERVICE_CONFIG_PATH + value: {{ .Values.config.file.launch | quote }} + - name: ML_INFRA_SERVICES_CONFIG_PATH + value: {{ .Values.config.file.infra | quote }} + {{- else }} + - name: DEPLOY_SERVICE_CONFIG_PATH + value: "/workspace/model-engine/service_configs/service_config.yaml" + - name: ML_INFRA_SERVICES_CONFIG_PATH + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" + {{- end }} + - name: CELERY_ELASTICACHE_ENABLED + value: "true" + - name: LAUNCH_SERVICE_TEMPLATE_FOLDER + value: "/workspace/model-engine/model_engine_server/infra/gateways/resources/templates" + {{- if .Values.redis.auth}} + - name: REDIS_AUTH_TOKEN + value: {{ .Values.redis.auth }} + {{- end }} + {{- if .Values.azure}} + - name: AZURE_IDENTITY_NAME + value: {{ .Values.azure.identity_name }} + - name: AZURE_CLIENT_ID + value: {{ .Values.azure.client_id }} + - name: AZURE_OBJECT_ID + value: {{ .Values.azure.object_id }} + - name: KEYVAULT_NAME + value: {{ .Values.azure.keyvault_name }} + - name: ABS_ACCOUNT_NAME + value: {{ .Values.azure.abs_account_name }} + - name: ABS_CONTAINER_NAME + value: {{ .Values.azure.abs_container_name }} + - name: SERVICEBUS_NAMESPACE + value: {{ .Values.azure.servicebus_namespace }} + {{- end }} + {{- if eq .Values.context "circleci" }} + - name: CIRCLECI + value: "true" + {{- end }} +{{- end }} + +{{- define "modelEngine.serviceEnvGitTagFromHelmVar" }} +{{- include "modelEngine.serviceEnvBase" . }} + - name: DD_VERSION + value: {{ .Values.tag }} + - name: GIT_TAG + value: {{ .Values.tag }} +{{- end }} + +{{- define "modelEngine.serviceEnvGitTagFromPythonReplace" }} +{{- include "modelEngine.serviceEnvBase" . }} + - name: DD_VERSION + value: "${GIT_TAG}" + - name: GIT_TAG + value: "${GIT_TAG}" +{{- end }} + + +{{- define "modelEngine.gatewayEnv" }} +{{- include "modelEngine.serviceEnvGitTagFromHelmVar" . }} + - name: DD_SERVICE + value: {{- printf " %s" (include "modelEngine.fullname" .) }} +{{- end }} + +{{- define "modelEngine.builderEnv" }} +{{- include "modelEngine.serviceEnvGitTagFromHelmVar" . }} + - name: DD_SERVICE + value: {{- printf " %s" (include "modelEngine.buildername" .) }} +{{- end }} + +{{- define "modelEngine.cacherEnv" }} +{{- include "modelEngine.serviceEnvGitTagFromHelmVar" . }} + - name: DD_SERVICE + value: {{- printf " %s" (include "modelEngine.cachername" .) }} +{{- end }} + +{{- define "modelEngine.volumes" }} +volumes: + - name: dshm + emptyDir: + medium: Memory + - name: service-template-config + configMap: + name: {{ include "modelEngine.fullname" . }}-service-template-config + {{- if .Values.aws }} + - name: config-volume + configMap: + name: {{ .Values.aws.configMap.name }} + {{- end }} + {{- if .Values.config.values }} + - name: {{ .Chart.Name }}-service-config-volume + configMap: + name: {{ include "modelEngine.fullname" . }}-service-config + items: + - key: launch_service_config + path: service_config.yaml + - name: infra-service-config-volume + configMap: + name: {{ include "modelEngine.fullname" . }}-service-config + items: + - key: infra_service_config + path: config.yaml + {{- end }} +{{- end }} + +{{- define "modelEngine.volumeMounts" }} +volumeMounts: + - name: dshm + mountPath: /dev/shm + - name: service-template-config + mountPath: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + {{- if .Values.aws }} + - name: config-volume + mountPath: {{ .Values.aws.configMap.mountPath }} + subPath: config + {{- end }} + {{- if .Values.config.values }} + - name: {{ .Chart.Name }}-service-config-volume + mountPath: /workspace/model-engine/service_configs + - name: infra-service-config-volume + mountPath: /workspace/model-engine/model_engine_server/core/configs + {{- end }} +{{- end }} + +{{- define "modelEngine.forwarderVolumeMounts" }} +volumeMounts: + {{- if .Values.aws }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - name: user-config + mountPath: /workspace/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /workspace/endpoint_config + subPath: raw_data + {{- if .Values.config.values }} + - name: infra-service-config-volume + mountPath: /workspace/model-engine/model_engine_server/core/configs + {{- end }} +{{- end }} + +{{- define "modelEngine.serviceAccountNamespaces" }} +namespaces: + - {{ .Release.Namespace }} +{{- range .Values.serviceAccount.namespaces }} + - {{ . }} +{{- end }} +{{- end }} diff --git a/charts/model-engine/templates/_istio-attribute-match-conditions.tpl b/charts/model-engine/templates/_istio-attribute-match-conditions.tpl new file mode 100644 index 00000000..6e9feeb1 --- /dev/null +++ b/charts/model-engine/templates/_istio-attribute-match-conditions.tpl @@ -0,0 +1,117 @@ +{{- /* Generated from the OpenAPI schema with model-engine-internal/scripts/generate_istio_metric_tags.py */}} +{{- define "modelEngine.istioAttributeMatchConditions" -}} +- condition: request.method == 'GET' && request.url_path == '/healthcheck' + value: get_/healthcheck +- condition: request.method == 'GET' && request.url_path == '/healthz' + value: get_/healthz +- condition: request.method == 'GET' && request.url_path == '/readyz' + value: get_/readyz +- condition: request.method == 'POST' && request.url_path == '/v1/async-tasks' + value: post_/v1/async-tasks +- condition: request.method == 'GET' && request.url_path.matches('^/v1/async-tasks/[[:alnum:]-_]*$') + value: get_/v1/async-tasks/_task_id +- condition: request.method == 'POST' && request.url_path == '/v1/batch-jobs' + value: post_/v1/batch-jobs +- condition: request.method == 'GET' && request.url_path.matches('^/v1/batch-jobs/[[:alnum:]-_]*$') + value: get_/v1/batch-jobs/_batch_job_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/batch-jobs/[[:alnum:]-_]*$') + value: put_/v1/batch-jobs/_batch_job_id +- condition: request.method == 'GET' && request.url_path == '/v1/docker-image-batch-job-bundles' + value: get_/v1/docker-image-batch-job-bundles +- condition: request.method == 'POST' && request.url_path == '/v1/docker-image-batch-job-bundles' + value: post_/v1/docker-image-batch-job-bundles +- condition: request.method == 'GET' && request.url_path == '/v1/docker-image-batch-job-bundles/latest' + value: get_/v1/docker-image-batch-job-bundles/latest +- condition: request.method == 'GET' && request.url_path.matches('^/v1/docker-image-batch-job-bundles/[[:alnum:]-_]*$') + value: get_/v1/docker-image-batch-job-bundles/_docker_image_batch_job_bundle_id +- condition: request.method == 'GET' && request.url_path == '/v1/docker-image-batch-jobs' + value: get_/v1/docker-image-batch-jobs +- condition: request.method == 'POST' && request.url_path == '/v1/docker-image-batch-jobs' + value: post_/v1/docker-image-batch-jobs +- condition: request.method == 'GET' && request.url_path.matches('^/v1/docker-image-batch-jobs/[[:alnum:]-_]*$') + value: get_/v1/docker-image-batch-jobs/_batch_job_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/docker-image-batch-jobs/[[:alnum:]-_]*$') + value: put_/v1/docker-image-batch-jobs/_batch_job_id +- condition: request.method == 'GET' && request.url_path == '/v1/files' + value: get_/v1/files +- condition: request.method == 'POST' && request.url_path == '/v1/files' + value: post_/v1/files +- condition: request.method == 'DELETE' && request.url_path.matches('^/v1/files/[[:alnum:]-_]*$') + value: delete_/v1/files/_file_id +- condition: request.method == 'GET' && request.url_path.matches('^/v1/files/[[:alnum:]-_]*$') + value: get_/v1/files/_file_id +- condition: request.method == 'GET' && request.url_path.matches('^/v1/files/[[:alnum:]-_]*/content$') + value: get_/v1/files/_file_id/content +- condition: request.method == 'POST' && request.url_path == '/v1/llm/completions-stream' + value: post_/v1/llm/completions-stream +- condition: request.method == 'POST' && request.url_path == '/v1/llm/completions-sync' + value: post_/v1/llm/completions-sync +- condition: request.method == 'GET' && request.url_path == '/v1/llm/fine-tunes' + value: get_/v1/llm/fine-tunes +- condition: request.method == 'POST' && request.url_path == '/v1/llm/fine-tunes' + value: post_/v1/llm/fine-tunes +- condition: request.method == 'GET' && request.url_path.matches('^/v1/llm/fine-tunes/[[:alnum:]-_]*$') + value: get_/v1/llm/fine-tunes/_fine_tune_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/llm/fine-tunes/[[:alnum:]-_]*/cancel$') + value: put_/v1/llm/fine-tunes/_fine_tune_id/cancel +- condition: request.method == 'GET' && request.url_path.matches('^/v1/llm/fine-tunes/[[:alnum:]-_]*/events$') + value: get_/v1/llm/fine-tunes/_fine_tune_id/events +- condition: request.method == 'GET' && request.url_path == '/v1/llm/model-endpoints' + value: get_/v1/llm/model-endpoints +- condition: request.method == 'POST' && request.url_path == '/v1/llm/model-endpoints' + value: post_/v1/llm/model-endpoints +- condition: request.method == 'POST' && request.url_path == '/v1/llm/model-endpoints/download' + value: post_/v1/llm/model-endpoints/download +- condition: request.method == 'DELETE' && request.url_path.matches('^/v1/llm/model-endpoints/[[:alnum:]-_]*$') + value: delete_/v1/llm/model-endpoints/_model_endpoint_name +- condition: request.method == 'GET' && request.url_path.matches('^/v1/llm/model-endpoints/[[:alnum:]-_]*$') + value: get_/v1/llm/model-endpoints/_model_endpoint_name +- condition: request.method == 'GET' && request.url_path == '/v1/model-bundles' + value: get_/v1/model-bundles +- condition: request.method == 'POST' && request.url_path == '/v1/model-bundles' + value: post_/v1/model-bundles +- condition: request.method == 'POST' && request.url_path == '/v1/model-bundles/clone-with-changes' + value: post_/v1/model-bundles/clone-with-changes +- condition: request.method == 'GET' && request.url_path == '/v1/model-bundles/latest' + value: get_/v1/model-bundles/latest +- condition: request.method == 'GET' && request.url_path.matches('^/v1/model-bundles/[[:alnum:]-_]*$') + value: get_/v1/model-bundles/_model_bundle_id +- condition: request.method == 'GET' && request.url_path == '/v1/model-endpoints' + value: get_/v1/model-endpoints +- condition: request.method == 'POST' && request.url_path == '/v1/model-endpoints' + value: post_/v1/model-endpoints +- condition: request.method == 'GET' && request.url_path == '/v1/model-endpoints-api' + value: get_/v1/model-endpoints-api +- condition: request.method == 'GET' && request.url_path == '/v1/model-endpoints-schema.json' + value: get_/v1/model-endpoints-schema.json +- condition: request.method == 'DELETE' && request.url_path.matches('^/v1/model-endpoints/[[:alnum:]-_]*$') + value: delete_/v1/model-endpoints/_model_endpoint_id +- condition: request.method == 'GET' && request.url_path.matches('^/v1/model-endpoints/[[:alnum:]-_]*$') + value: get_/v1/model-endpoints/_model_endpoint_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/model-endpoints/[[:alnum:]-_]*$') + value: put_/v1/model-endpoints/_model_endpoint_id +- condition: request.method == 'POST' && request.url_path == '/v1/streaming-tasks' + value: post_/v1/streaming-tasks +- condition: request.method == 'POST' && request.url_path == '/v1/sync-tasks' + value: post_/v1/sync-tasks +- condition: request.method == 'GET' && request.url_path == '/v1/triggers' + value: get_/v1/triggers +- condition: request.method == 'POST' && request.url_path == '/v1/triggers' + value: post_/v1/triggers +- condition: request.method == 'DELETE' && request.url_path.matches('^/v1/triggers/[[:alnum:]-_]*$') + value: delete_/v1/triggers/_trigger_id +- condition: request.method == 'GET' && request.url_path.matches('^/v1/triggers/[[:alnum:]-_]*$') + value: get_/v1/triggers/_trigger_id +- condition: request.method == 'PUT' && request.url_path.matches('^/v1/triggers/[[:alnum:]-_]*$') + value: put_/v1/triggers/_trigger_id +- condition: request.method == 'GET' && request.url_path == '/v2/model-bundles' + value: get_/v2/model-bundles +- condition: request.method == 'POST' && request.url_path == '/v2/model-bundles' + value: post_/v2/model-bundles +- condition: request.method == 'POST' && request.url_path == '/v2/model-bundles/clone-with-changes' + value: post_/v2/model-bundles/clone-with-changes +- condition: request.method == 'GET' && request.url_path == '/v2/model-bundles/latest' + value: get_/v2/model-bundles/latest +- condition: request.method == 'GET' && request.url_path.matches('^/v2/model-bundles/[[:alnum:]-_]*$') + value: get_/v2/model-bundles/_model_bundle_id +{{- end -}} diff --git a/charts/model-engine/templates/aws_config_map.yaml b/charts/model-engine/templates/aws_config_map.yaml new file mode 100644 index 00000000..60b91c97 --- /dev/null +++ b/charts/model-engine/templates/aws_config_map.yaml @@ -0,0 +1,26 @@ +{{- if .Values.aws }} +{{- if eq .Values.aws.configMap.create true }} +{{- $name := .Values.aws.configMap.name }} +{{- $profileName := .Values.aws.profileName }} +{{- $annotations := .Values.serviceAccount.annotations }} +{{- $labels := include "modelEngine.labels" . }} +{{- range $namespace := .Values.aws.configMap.namespaces }} +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ $name }} + namespace: {{- printf " %s" $namespace }} + labels: + {{- $labels | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" +data: + config: |- + [profile {{ $profileName }}] + role_arn = {{ index $annotations "eks.amazonaws.com/role-arn" }} + web_identity_token_file = /var/run/secrets/eks.amazonaws.com/serviceaccount/token +--- +{{- end }} +{{- end }} +{{- end }} diff --git a/charts/llm-engine/templates/balloon_cpu_deployment.yaml b/charts/model-engine/templates/balloon_cpu_deployment.yaml similarity index 62% rename from charts/llm-engine/templates/balloon_cpu_deployment.yaml rename to charts/model-engine/templates/balloon_cpu_deployment.yaml index 6849bc61..583e3c1e 100644 --- a/charts/llm-engine/templates/balloon_cpu_deployment.yaml +++ b/charts/model-engine/templates/balloon_cpu_deployment.yaml @@ -1,30 +1,34 @@ {{- if not .Values.serviceIdentifier }} +{{- range .Values.balloons }} +{{- if eq .acceleratorName "cpu" }} apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-balloon-cpu + name: {{ $.Chart.Name }}-balloon-cpu labels: team: infra product: common-warm-nodes spec: - replicas: {{ .Values.replicaCount.balloonCpu }} + replicas: {{ .replicaCount }} selector: matchLabels: - app: llm-engine-balloon-cpu + app: {{ $.Chart.Name }}-balloon-cpu version: v1 template: metadata: labels: - app: llm-engine-balloon-cpu + app: {{ $.Chart.Name }}-balloon-cpu product: common-warm-nodes team: infra - env: {{ .Values.context }} + env: {{ $.Values.context }} version: v1 annotations: sidecar.istio.io/inject: "false" spec: + {{- with $.Values.balloonNodeSelector }} nodeSelector: - node-lifecycle: normal + {{- toYaml . | nindent 8 }} + {{- end }} containers: - image: public.ecr.aws/ubuntu/ubuntu:latest imagePullPolicy: IfNotPresent @@ -32,11 +36,13 @@ spec: resources: limits: memory: 28Gi - cpu: 8 + cpu: 6 command: - /bin/bash - -c - "while true; do sleep 30; done" terminationGracePeriodSeconds: 0 - priorityClassName: llm-engine-low-priority + priorityClassName: {{ $.Chart.Name }}-low-priority +{{- end }} +{{- end }} {{- end }} diff --git a/charts/llm-engine/templates/balloon_a10_deployment.yaml b/charts/model-engine/templates/balloon_deployments.yaml similarity index 58% rename from charts/llm-engine/templates/balloon_a10_deployment.yaml rename to charts/model-engine/templates/balloon_deployments.yaml index 183392f1..49a1890f 100644 --- a/charts/llm-engine/templates/balloon_a10_deployment.yaml +++ b/charts/model-engine/templates/balloon_deployments.yaml @@ -1,31 +1,35 @@ {{- if not .Values.serviceIdentifier }} +{{- range .Values.balloons }} +{{- if not (eq .acceleratorName "cpu") }} apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-balloon-a10 + name: {{ $.Chart.Name }}-balloon-{{ .acceleratorName }} labels: team: infra product: common-warm-nodes spec: - replicas: {{ .Values.replicaCount.balloonA10 }} + replicas: {{ .replicaCount }} selector: matchLabels: - app: llm-engine-balloon-a10 + app: {{ $.Chart.Name }}-balloon-{{ .acceleratorName }} version: v1 template: metadata: labels: - app: llm-engine-balloon-a10 + app: {{ $.Chart.Name }}-balloon-{{ .acceleratorName }} product: common-warm-nodes team: infra - env: {{ .Values.context }} + env: {{ $.Values.context }} version: v1 annotations: sidecar.istio.io/inject: "false" spec: nodeSelector: - k8s.amazonaws.com/accelerator: nvidia-ampere-a10 - node-lifecycle: normal + k8s.amazonaws.com/accelerator: {{ .acceleratorName }} + {{- with $.Values.balloonNodeSelector }} + {{- toYaml . | nindent 8 }} + {{- end }} tolerations: - key: "nvidia.com/gpu" operator: "Exists" @@ -37,12 +41,15 @@ spec: resources: limits: memory: 28Gi - nvidia.com/gpu: 1 + nvidia.com/gpu: {{ .gpuCount | default 1 }} cpu: 4 command: - /bin/bash - -c - "while true; do sleep 30; done" terminationGracePeriodSeconds: 0 - priorityClassName: llm-engine-low-priority + priorityClassName: {{ $.Chart.Name }}-low-priority +--- +{{- end }} +{{- end }} {{- end }} diff --git a/charts/llm-engine/templates/cacher_deployment.yaml b/charts/model-engine/templates/cacher_deployment.yaml similarity index 57% rename from charts/llm-engine/templates/cacher_deployment.yaml rename to charts/model-engine/templates/cacher_deployment.yaml index 6c833c35..09297aba 100644 --- a/charts/llm-engine/templates/cacher_deployment.yaml +++ b/charts/model-engine/templates/cacher_deployment.yaml @@ -1,28 +1,28 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: {{ include "llmEngine.cachername" . }} + name: {{ include "modelEngine.cachername" . }} labels: - {{- include "llmEngine.selectorLabels.cacher" . | nindent 4 }} - {{- include "llmEngine.labels" . | nindent 4 }} - tags.datadoghq.com/service: {{ include "llmEngine.cachername" . }} + {{- include "modelEngine.selectorLabels.cacher" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} + tags.datadoghq.com/service: {{ include "modelEngine.cachername" . }} spec: replicas: {{ .Values.replicaCount.cacher }} selector: matchLabels: - {{- include "llmEngine.selectorLabels.cacher" . | nindent 6 }} + {{- include "modelEngine.selectorLabels.cacher" . | nindent 6 }} template: metadata: annotations: ad.datadoghq.com/main.logs: | [{ - "service": {{ include "llmEngine.cachername" . | quote }}, + "service": {{ include "modelEngine.cachername" . | quote }}, "source": "python" }] labels: - {{- include "llmEngine.selectorLabels.cacher" . | nindent 8 }} - {{- include "llmEngine.labels" . | nindent 8 }} - tags.datadoghq.com/service: {{ include "llmEngine.cachername" . }} + {{- include "modelEngine.selectorLabels.cacher" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} + tags.datadoghq.com/service: {{ include "modelEngine.cachername" . }} sidecar.istio.io/inject: "false" spec: {{- with .Values.imagePullSecrets }} @@ -30,7 +30,7 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.cachername" . }} + - name: {{ include "modelEngine.cachername" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} ports: @@ -45,17 +45,19 @@ spec: command: - dumb-init - -- + {{- if .Values.datadog.enabled }} - ddtrace-run + {{- end }} args: - python - -m - - server.llm_engine_server.entrypoints.k8s_cache + - model_engine_server.entrypoints.k8s_cache resources: {{- toYaml .Values.resources | nindent 12 }} - {{- include "llmEngine.cacherEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + {{- include "modelEngine.cacherEnv" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/llm-engine/templates/cacher_vpa.yaml b/charts/model-engine/templates/cacher_vpa.yaml similarity index 74% rename from charts/llm-engine/templates/cacher_vpa.yaml rename to charts/model-engine/templates/cacher_vpa.yaml index 0b79d1d5..4a07b3df 100644 --- a/charts/llm-engine/templates/cacher_vpa.yaml +++ b/charts/model-engine/templates/cacher_vpa.yaml @@ -2,19 +2,19 @@ apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler metadata: - name: {{ include "llmEngine.cachername" . }} + name: {{ include "modelEngine.cachername" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} spec: targetRef: apiVersion: "apps/v1" kind: Deployment - name: {{ include "llmEngine.cachername" . }} + name: {{ include "modelEngine.cachername" . }} updatePolicy: updateMode: "Auto" resourcePolicy: containerPolicies: - - containerName: {{ include "llmEngine.cachername" . }} + - containerName: {{ include "modelEngine.cachername" . }} minAllowed: cpu: {{ .Values.autoscaling.vertical.minAllowed.cpu }} memory: {{ .Values.autoscaling.vertical.minAllowed.memory }} diff --git a/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml new file mode 100644 index 00000000..93d359b5 --- /dev/null +++ b/charts/model-engine/templates/celery_autoscaler_stateful_set.yaml @@ -0,0 +1,112 @@ +{{- if .Values.celery_autoscaler.enabled }} +{{- if not .Values.serviceIdentifier }} +{{- $app := include "modelEngine.celeryautoscalername" . }} +{{- $env := .Values.context }} +{{- $tag := .Values.tag }} +{{- $message_broker := .Values.celeryBrokerType }} +{{- $num_shards := .Values.celery_autoscaler.num_shards }} +{{- $broker_name := "redis-elasticache-message-broker-master" }} +{{- if eq $message_broker "sqs" }} +{{ $broker_name = "sqs-message-broker-master" }} +{{- else if eq $message_broker "servicebus" }} +{{ $broker_name = "servicebus-message-broker-master" }} +{{- end }} +apiVersion: apps/v1 +kind: StatefulSet +metadata: + labels: + {{- include "modelEngine.baseLabels" . | nindent 4 }} + {{- include "modelEngine.selectorLabels.celeryAutoscaler" . | nindent 4 }} + name: {{ $app }} +spec: + serviceName: {{ $app }} + replicas: {{ $num_shards }} + selector: + matchLabels: + app: {{ $app }} + template: + metadata: + annotations: + ad.datadoghq.com/main.logs: '[{"service": "{{ $app }}", "source": "python"}]' + sidecar.istio.io/inject: "false" + labels: + {{- include "modelEngine.baseLabels" . | nindent 8 }} + {{- include "modelEngine.selectorLabels.celeryAutoscaler" . | nindent 8 }} + spec: + containers: + - args: + {{- if .Values.datadog.enabled }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.core.celery.celery_autoscaler + env: + {{- if .Values.aws }} + - name: AWS_PROFILE + value: {{ .Values.aws.profileName }} + - name: AWS_CONFIG_FILE + value: /opt/.aws/config + {{- end }} + - name: DD_TRACE_ENABLED + value: 'false' + - name: DD_SERVICE + value: {{ $app }} + - name: DD_ENV + value: {{ $env }} + - name: DD_VERSION + value: {{ $tag }} + - name: DD_AGENT_HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + - name: BROKER_NAME + value: {{ $broker_name }} + - name: REDIS_BROKER_NAME + value: {{ $broker_name }} + - name: CELERY_ELASTICACHE_ENABLED + value: {{ (eq $message_broker "elasticache") | squote }} + - name: POD_NAME + valueFrom: + fieldRef: + fieldPath: metadata.name + - name: NUM_SHARDS + value: '{{ $num_shards }}' + {{- if .Values.azure }} + - name: AZURE_CLIENT_ID + value: {{ .Values.azure.client_id }} + - name: AZURE_OBJECT_ID + value: {{ .Values.azure.object_id }} + - name: SERVICEBUS_NAMESPACE + value: {{ .Values.azure.servicebus_namespace }} + {{- end }} + image: "{{ .Values.image.gatewayRepository }}:{{ $tag }}" + imagePullPolicy: Always + name: main + resources: + requests: + cpu: 1000m + {{- if .Values.aws }} + volumeMounts: + - mountPath: /opt/.aws/config + name: config-volume + subPath: config + {{- end }} + {{ with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + tolerations: + - key: CriticalAddonsOnly + operator: Equal + value: 'true' + effect: NoSchedule + serviceAccountName: {{ include "modelEngine.fullname" $ }} + {{- if .Values.aws }} + volumes: + - configMap: + name: {{ .Values.aws.configMap.name }} + name: config-volume + {{- end}} +{{- end }} +{{- end }} diff --git a/charts/llm-engine/templates/cluster_rolebinding.yaml b/charts/model-engine/templates/cluster_rolebinding.yaml similarity index 58% rename from charts/llm-engine/templates/cluster_rolebinding.yaml rename to charts/model-engine/templates/cluster_rolebinding.yaml index b438ae93..bdafd94b 100644 --- a/charts/llm-engine/templates/cluster_rolebinding.yaml +++ b/charts/model-engine/templates/cluster_rolebinding.yaml @@ -1,11 +1,11 @@ -{{- $serviceAccountName := include "llmEngine.fullname" . }} -{{- $serviceAccountNamespaces := (include "llmEngine.serviceAccountNamespaces" . | fromYaml) }} +{{- $serviceAccountName := include "modelEngine.fullname" . }} +{{- $serviceAccountNamespaces := (include "modelEngine.serviceAccountNamespaces" . | fromYaml) }} apiVersion: rbac.authorization.k8s.io/v1 kind: ClusterRoleBinding metadata: - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} roleRef: apiGroup: rbac.authorization.k8s.io kind: ClusterRole diff --git a/charts/llm-engine/templates/database_init_job.yaml b/charts/model-engine/templates/database_migration_job.yaml similarity index 57% rename from charts/llm-engine/templates/database_init_job.yaml rename to charts/model-engine/templates/database_migration_job.yaml index f743d0b6..183814c6 100644 --- a/charts/llm-engine/templates/database_init_job.yaml +++ b/charts/model-engine/templates/database_migration_job.yaml @@ -1,12 +1,12 @@ -{{- if .Values.secrets.kubernetesDatabaseSecretName }} +{{- if or (.Values.secrets.kubernetesDatabaseSecretName) (.Values.db.runDbMigrationScript) }} apiVersion: batch/v1 kind: Job metadata: - name: {{ include "llmEngine.fullname" . }}-database-setup + name: {{ include "modelEngine.fullname" . }}-database-migration labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} annotations: - "helm.sh/hook": pre-install + "helm.sh/hook": pre-install,pre-upgrade "helm.sh/hook-weight": "-1" "helm.sh/hook-delete-policy": hook-succeeded spec: @@ -16,7 +16,7 @@ spec: metadata: labels: sidecar.istio.io/inject: "false" - {{- include "llmEngine.labels" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} spec: restartPolicy: Never {{- with .Values.imagePullSecrets }} @@ -24,20 +24,19 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.fullname" . }} + - name: {{ include "modelEngine.fullname" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} command: - dumb-init - -- args: - - python - - -m - - server.llm_engine_server.entrypoints.init_database - {{- include "llmEngine.serviceEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + - bash + - /workspace/model-engine/model_engine_server/db/migrations/run_database_migration.sh + {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/llm-engine/templates/endpoint_builder_deployment.yaml b/charts/model-engine/templates/endpoint_builder_deployment.yaml similarity index 57% rename from charts/llm-engine/templates/endpoint_builder_deployment.yaml rename to charts/model-engine/templates/endpoint_builder_deployment.yaml index fbd85a69..2868e87b 100644 --- a/charts/llm-engine/templates/endpoint_builder_deployment.yaml +++ b/charts/model-engine/templates/endpoint_builder_deployment.yaml @@ -1,29 +1,29 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: {{ include "llmEngine.buildername" . }} + name: {{ include "modelEngine.buildername" . }} labels: - {{- include "llmEngine.selectorLabels.builder" . | nindent 4 }} - {{- include "llmEngine.labels" . | nindent 4 }} - tags.datadoghq.com/service: {{ include "llmEngine.buildername" . }} + {{- include "modelEngine.selectorLabels.builder" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} + tags.datadoghq.com/service: {{ include "modelEngine.buildername" . }} spec: replicas: {{ .Values.replicaCount.builder }} selector: matchLabels: - {{- include "llmEngine.selectorLabels.builder" . | nindent 6 }} + {{- include "modelEngine.selectorLabels.builder" . | nindent 6 }} template: metadata: annotations: cluster-autoscaler.kubernetes.io/safe-to-evict: "false" ad.datadoghq.com/main.logs: | [{ - "service": {{ include "llmEngine.buildername" . | quote }}, + "service": {{ include "modelEngine.buildername" . | quote }}, "source": "python" }] labels: - {{- include "llmEngine.selectorLabels.builder" . | nindent 8 }} - {{- include "llmEngine.labels" . | nindent 8 }} - tags.datadoghq.com/service: {{ include "llmEngine.buildername" . }} + {{- include "modelEngine.selectorLabels.builder" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} + tags.datadoghq.com/service: {{ include "modelEngine.buildername" . }} sidecar.istio.io/inject: "false" spec: {{- with .Values.imagePullSecrets }} @@ -31,7 +31,7 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.buildername" . }} + - name: {{ include "modelEngine.buildername" . }} image: "{{ .Values.image.builderRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} ports: @@ -46,24 +46,26 @@ spec: command: - dumb-init - -- + {{- if .Values.datadog.enabled }} - ddtrace-run + {{- end }} args: - celery - - --app=server.llm_engine_server.service_builder + - --app=model_engine_server.service_builder - worker - --loglevel=INFO - --concurrency=2 {{- if .Values.serviceIdentifier }} - - --queues=llm-engine-{{ .Values.serviceIdentifier }}.service-builder + - --queues=model-engine-{{ .Values.serviceIdentifier }}-service-builder {{- else }} - - --queues=llm-engine.service-builder + - --queues=model-engine-service-builder {{- end }} resources: {{- toYaml .Values.resources | nindent 12 }} - {{- include "llmEngine.builderEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + {{- include "modelEngine.builderEnv" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/llm-engine/templates/endpoint_builder_vpa.yaml b/charts/model-engine/templates/endpoint_builder_vpa.yaml similarity index 74% rename from charts/llm-engine/templates/endpoint_builder_vpa.yaml rename to charts/model-engine/templates/endpoint_builder_vpa.yaml index e467e53a..64983d94 100644 --- a/charts/llm-engine/templates/endpoint_builder_vpa.yaml +++ b/charts/model-engine/templates/endpoint_builder_vpa.yaml @@ -2,19 +2,19 @@ apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler metadata: - name: {{ include "llmEngine.buildername" . }} + name: {{ include "modelEngine.buildername" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} spec: targetRef: apiVersion: "apps/v1" kind: Deployment - name: {{ include "llmEngine.buildername" . }} + name: {{ include "modelEngine.buildername" . }} updatePolicy: updateMode: "Auto" resourcePolicy: containerPolicies: - - containerName: {{ include "llmEngine.buildername" . }} + - containerName: {{ include "modelEngine.buildername" . }} minAllowed: cpu: {{ .Values.autoscaling.vertical.minAllowed.cpu }} memory: {{ .Values.autoscaling.vertical.minAllowed.memory }} diff --git a/charts/llm-engine/templates/gateway_deployment.yaml b/charts/model-engine/templates/gateway_deployment.yaml similarity index 57% rename from charts/llm-engine/templates/gateway_deployment.yaml rename to charts/model-engine/templates/gateway_deployment.yaml index e2753524..937071b2 100644 --- a/charts/llm-engine/templates/gateway_deployment.yaml +++ b/charts/model-engine/templates/gateway_deployment.yaml @@ -1,11 +1,11 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} labels: - {{- include "llmEngine.selectorLabels.gateway" . | nindent 4 }} - {{- include "llmEngine.labels" . | nindent 4 }} - tags.datadoghq.com/service: {{ include "llmEngine.fullname" . }} + {{- include "modelEngine.selectorLabels.gateway" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} + tags.datadoghq.com/service: {{ include "modelEngine.fullname" . }} spec: {{- if not .Values.autoscaling.horizontal.enabled }} replicas: {{ .Values.replicaCount.gateway }} @@ -17,26 +17,29 @@ spec: maxSurge: 25% selector: matchLabels: - {{- include "llmEngine.selectorLabels.gateway" . | nindent 6 }} + {{- include "modelEngine.selectorLabels.gateway" . | nindent 6 }} template: metadata: annotations: ad.datadoghq.com/main.logs: | [{ - "service": {{ include "llmEngine.fullname" . | quote }}, + "service": {{ include "modelEngine.fullname" . | quote }}, "source": "python" }] + sidecar.istio.io/proxyMemoryLimit: "5Gi" + sidecar.istio.io/proxyMemory: "1Gi" labels: - {{- include "llmEngine.selectorLabels.gateway" . | nindent 8 }} - {{- include "llmEngine.labels" . | nindent 8 }} - tags.datadoghq.com/service: {{ include "llmEngine.fullname" . }} + {{- include "modelEngine.selectorLabels.gateway" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} + tags.datadoghq.com/service: {{ include "modelEngine.fullname" . }} spec: {{- with .Values.imagePullSecrets }} imagePullSecrets: {{- toYaml . | nindent 8 }} {{- end }} + priorityClassName: model-engine-high-priority containers: - - name: {{ include "llmEngine.fullname" . }} + - name: {{ include "modelEngine.fullname" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} ports: @@ -49,27 +52,22 @@ spec: port: 5000 periodSeconds: 2 failureThreshold: 30 - livenessProbe: - httpGet: - path: /healthz - port: 5000 - initialDelaySeconds: 5 - periodSeconds: 2 - failureThreshold: 10 command: - dumb-init - -- + {{- if .Values.datadog.enabled }} - ddtrace-run + {{- end }} args: - python - -m - - server.llm_engine_server.entrypoints.start_fastapi_server + - model_engine_server.entrypoints.start_fastapi_server resources: {{- toYaml .Values.resources | nindent 12 }} - {{- include "llmEngine.gatewayEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + {{- include "modelEngine.gatewayEnv" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/llm-engine/templates/gateway_hpa.yaml b/charts/model-engine/templates/gateway_hpa.yaml similarity index 78% rename from charts/llm-engine/templates/gateway_hpa.yaml rename to charts/model-engine/templates/gateway_hpa.yaml index f9cd542e..9238b538 100644 --- a/charts/llm-engine/templates/gateway_hpa.yaml +++ b/charts/model-engine/templates/gateway_hpa.yaml @@ -2,14 +2,14 @@ apiVersion: autoscaling/v2 kind: HorizontalPodAutoscaler metadata: - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} spec: scaleTargetRef: apiVersion: apps/v1 kind: Deployment - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} minReplicas: {{ .Values.autoscaling.horizontal.minReplicas }} maxReplicas: {{ .Values.autoscaling.horizontal.maxReplicas }} metrics: diff --git a/charts/model-engine/templates/gateway_service.yaml b/charts/model-engine/templates/gateway_service.yaml new file mode 100644 index 00000000..1407ebef --- /dev/null +++ b/charts/model-engine/templates/gateway_service.yaml @@ -0,0 +1,18 @@ +apiVersion: v1 +kind: Service +metadata: + name: {{ include "modelEngine.fullname" . }} + labels: + {{- include "modelEngine.labels" . | nindent 4 }} +spec: + type: {{ .Values.service.type }} + ports: + - port: {{ .Values.service.port }} + targetPort: http + protocol: TCP + name: http + {{- with .Values.service.nodePort }} + nodePort: {{ . }} + {{- end }} + selector: + {{- include "modelEngine.selectorLabels.gateway" . | nindent 4 }} diff --git a/charts/llm-engine/templates/gateway_vpa.yaml b/charts/model-engine/templates/gateway_vpa.yaml similarity index 77% rename from charts/llm-engine/templates/gateway_vpa.yaml rename to charts/model-engine/templates/gateway_vpa.yaml index 4e93cd8a..061ed8cf 100644 --- a/charts/llm-engine/templates/gateway_vpa.yaml +++ b/charts/model-engine/templates/gateway_vpa.yaml @@ -2,21 +2,21 @@ apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler metadata: - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} spec: targetRef: apiVersion: "apps/v1" kind: Deployment - name: {{ include "llmEngine.fullname" . }} + name: {{ include "modelEngine.fullname" . }} updatePolicy: updateMode: {{ .Values.autoscaling.vertical.updateMode }} resourcePolicy: containerPolicies: - containerName: istio-proxy mode: "Off" - - containerName: {{ include "llmEngine.fullname" . }} + - containerName: {{ include "modelEngine.fullname" . }} minAllowed: cpu: {{ .Values.autoscaling.vertical.minAllowed.cpu }} memory: {{ .Values.autoscaling.vertical.minAllowed.memory }} diff --git a/charts/model-engine/templates/inference_framework_config.yaml b/charts/model-engine/templates/inference_framework_config.yaml new file mode 100644 index 00000000..45759d77 --- /dev/null +++ b/charts/model-engine/templates/inference_framework_config.yaml @@ -0,0 +1,18 @@ +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "modelEngine.fullname" . }}-inference-framework-latest-config + labels: + product: common + team: infra + annotations: + "helm.sh/hook": pre-install + "helm.sh/hook-weight": "-2" +data: + deepspeed: "latest" + text_generation_inference: "latest" + vllm: "latest" + vllm_batch: "latest" + vllm_batch_v2: "latest" + lightllm: "latest" + tensorrt_llm: "latest" diff --git a/charts/model-engine/templates/istio-destinationrule.yaml b/charts/model-engine/templates/istio-destinationrule.yaml new file mode 100644 index 00000000..12b51afb --- /dev/null +++ b/charts/model-engine/templates/istio-destinationrule.yaml @@ -0,0 +1,18 @@ +{{- if .Values.destinationrule.enabled -}} +{{- $fullName := include "modelEngine.fullname" . -}} +apiVersion: networking.istio.io/v1beta1 +kind: DestinationRule +metadata: + name: {{ $fullName }} + labels: + {{- include "modelEngine.labels" . | nindent 4}} + {{- with .Values.destinationrule.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + host: "{{ $fullName }}.{{ .Release.Namespace }}.svc.cluster.local" + trafficPolicy: + loadBalancer: + simple: LEAST_REQUEST # Requires later version of Istio, which we have on the new clusters +{{- end }} diff --git a/charts/model-engine/templates/istio-metrics.yaml b/charts/model-engine/templates/istio-metrics.yaml new file mode 100644 index 00000000..4f19cf73 --- /dev/null +++ b/charts/model-engine/templates/istio-metrics.yaml @@ -0,0 +1,36 @@ +{{- if empty .Values.azure }} +apiVersion: telemetry.istio.io/v1alpha1 +kind: Telemetry +metadata: + name: {{ include "modelEngine.fullname" . }}-custom-tags + namespace: istio-system +spec: + metrics: + - overrides: + - match: + metric: REQUEST_COUNT + mode: CLIENT_AND_SERVER + tagOverrides: + request_operation: + value: istio_requestOperation + providers: + - name: prometheus +--- +apiVersion: extensions.istio.io/v1alpha1 +kind: WasmPlugin +metadata: + name: {{ include "modelEngine.fullname" . }}-attributegen + namespace: istio-system +spec: + imagePullPolicy: Always + phase: AUTHN + pluginConfig: + attributes: + - match: + {{- include "modelEngine.istioAttributeMatchConditions" . | nindent 6 }} + output_attribute: istio_requestOperation + selector: + matchLabels: + {{- include "modelEngine.selectorLabels.gateway" . | nindent 6 }} + url: https://storage.googleapis.com/istio-build/proxy/attributegen-359dcd3a19f109c50e97517fe6b1e2676e870c4d.wasm +{{- end }} diff --git a/charts/model-engine/templates/istio-virtualservice.yaml b/charts/model-engine/templates/istio-virtualservice.yaml new file mode 100644 index 00000000..1bd26e14 --- /dev/null +++ b/charts/model-engine/templates/istio-virtualservice.yaml @@ -0,0 +1,31 @@ +{{- if .Values.virtualservice.enabled -}} +{{- $fullName := include "modelEngine.fullname" . -}} +apiVersion: networking.istio.io/v1alpha3 +kind: VirtualService +metadata: + name: {{ $fullName }} + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + {{- with .Values.virtualservice.annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +spec: + hosts: + {{- range .Values.virtualservice.hostDomains }} + - "{{ $fullName }}.{{ . }}" + {{- end }} + gateways: + {{- range .Values.virtualservice.gateways }} + - {{ . | quote }} + {{- end }} + http: + - route: + - destination: + host: "{{ $fullName }}.{{ .Release.Namespace }}.svc.cluster.local" + port: + number: 80 + retries: + attempts: 3 + retryOn: connect-failure,unavailable,gateway-error +{{- end }} diff --git a/charts/model-engine/templates/model_engine_default_priority_class.yaml b/charts/model-engine/templates/model_engine_default_priority_class.yaml new file mode 100644 index 00000000..a2b80367 --- /dev/null +++ b/charts/model-engine/templates/model_engine_default_priority_class.yaml @@ -0,0 +1,13 @@ +{{- if not .Values.serviceIdentifier }} +apiVersion: scheduling.k8s.io/v1 +kind: PriorityClass +metadata: + name: "{{ include "modelEngine.fullname" . }}-default-priority" +value: 1 +{{- if .Values.balloonConfig.reserveHighPriority }} +# This ensures that the default launch pods will never preempt any pods, which means +# they cannot take advantage of the dummy nodes. +preemptionPolicy: Never +{{- end }} +description: "Default Priority Class for Launch" +{{- end }} diff --git a/charts/llm-engine/templates/launch_high_priority_class.yaml b/charts/model-engine/templates/model_engine_high_priority_class.yaml similarity index 53% rename from charts/llm-engine/templates/launch_high_priority_class.yaml rename to charts/model-engine/templates/model_engine_high_priority_class.yaml index dd088b91..5dbfa7f0 100644 --- a/charts/llm-engine/templates/launch_high_priority_class.yaml +++ b/charts/model-engine/templates/model_engine_high_priority_class.yaml @@ -2,7 +2,7 @@ apiVersion: scheduling.k8s.io/v1 kind: PriorityClass metadata: - name: "{{ include "llmEngine.fullname" . }}-high-priority" + name: "{{ include "modelEngine.fullname" . }}-high-priority" value: 100000 -description: "High Priority Class for LLMEngine" +description: "High Priority Class for Launch" {{- end }} diff --git a/charts/llm-engine/templates/launch_low_priority_class.yaml b/charts/model-engine/templates/model_engine_low_priority_class.yaml similarity index 53% rename from charts/llm-engine/templates/launch_low_priority_class.yaml rename to charts/model-engine/templates/model_engine_low_priority_class.yaml index f40db336..71deb6c2 100644 --- a/charts/llm-engine/templates/launch_low_priority_class.yaml +++ b/charts/model-engine/templates/model_engine_low_priority_class.yaml @@ -2,7 +2,7 @@ apiVersion: scheduling.k8s.io/v1 kind: PriorityClass metadata: - name: "{{ include "llmEngine.fullname" . }}-low-priority" + name: "{{ include "modelEngine.fullname" . }}-low-priority" value: 0 -description: "Low Priority Class for LLMEngine" +description: "Low Priority Class for Launch" {{- end }} diff --git a/charts/model-engine/templates/pod_disruption_budget.yaml b/charts/model-engine/templates/pod_disruption_budget.yaml new file mode 100644 index 00000000..67b7a02f --- /dev/null +++ b/charts/model-engine/templates/pod_disruption_budget.yaml @@ -0,0 +1,17 @@ +{{- if and .Values.podDisruptionBudget .Values.podDisruptionBudget.enabled }} +apiVersion: policy/v1 +kind: PodDisruptionBudget +metadata: + name: {{ include "modelEngine.fullname" . }} + labels: + {{- include "modelEngine.labels" . | nindent 4 }} +spec: + {{- if .Values.podDisruptionBudget.minAvailable }} + minAvailable: {{ .Values.podDisruptionBudget.minAvailable }} + {{- else }} + maxUnavailable: {{ .Values.podDisruptionBudget.maxUnavailable }} + {{- end }} + selector: + matchLabels: + {{- include "modelEngine.selectorLabels.gateway" . | nindent 6 }} +{{- end }} diff --git a/charts/model-engine/templates/populate_fine_tuning_repository_job.yaml b/charts/model-engine/templates/populate_fine_tuning_repository_job.yaml new file mode 100644 index 00000000..080f21e6 --- /dev/null +++ b/charts/model-engine/templates/populate_fine_tuning_repository_job.yaml @@ -0,0 +1,58 @@ +{{- if .Values.populateFineTuningRepository }} +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ include "modelEngine.fullname" . }}-populate-fine-tuning-repository + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": post-install + "helm.sh/hook-weight": "1" + "helm.sh/hook-delete-policy": hook-succeeded +spec: + backoffLimit: 0 + activeDeadlineSeconds: 600 + template: + metadata: + labels: + sidecar.istio.io/inject: "false" + {{- include "modelEngine.labels" . | nindent 8 }} + spec: + restartPolicy: Never + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + containers: + - name: {{ include "modelEngine.fullname" . }} + image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + command: + - dumb-init + - -- + args: + - python + - -m + - model_engine_server.entrypoints.populate_llm_fine_tuning_job_repository + {{- if .Values.azure }} + - --cloud-provider + - azure + {{- end }} + - --initialize-repository + {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} + {{- with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} +{{- end }} diff --git a/charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml b/charts/model-engine/templates/proportional_a100_autoscaler_deployment.yaml similarity index 79% rename from charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml rename to charts/model-engine/templates/proportional_a100_autoscaler_deployment.yaml index f288bdf1..f89f298e 100644 --- a/charts/llm-engine/templates/proportional_a100_autoscaler_deployment.yaml +++ b/charts/model-engine/templates/proportional_a100_autoscaler_deployment.yaml @@ -3,19 +3,19 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-proportional-a100-autoscaler-deployment + name: {{ .Chart.Name }}-proportional-a100-autoscaler-deployment labels: team: infra product: common-warm-nodes spec: selector: matchLabels: - app: llm-engine-proportional-a100-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-a100-autoscaler-deployment version: v1 template: metadata: labels: - app: llm-engine-proportional-a100-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-a100-autoscaler-deployment product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -38,12 +38,12 @@ spec: - /cluster-proportional-autoscaler - --namespace={{ .Release.Namespace }} - --configmap=cluster-proportional-autoscaler - - --target=deployment/llm-engine-balloon-a100 + - --target=deployment/{{ .Chart.Name }}-balloon-a100 - --default-params={"linear":{"nodesPerReplica":10,"preventSinglePointFailure":false,"includeUnschedulableNodes":false}} - --nodelabels=k8s.amazonaws.com/accelerator=nvidia-ampere-a100 - --logtostderr=true - --v=2 priorityClassName: system-cluster-critical - serviceAccountName: {{ include "llmEngine.fullname" . }} + serviceAccountName: {{ include "modelEngine.fullname" . }} {{- end }} {{- end }} diff --git a/charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml b/charts/model-engine/templates/proportional_a10_autoscaler_deployment.yaml similarity index 79% rename from charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml rename to charts/model-engine/templates/proportional_a10_autoscaler_deployment.yaml index d6fd7594..70274d26 100644 --- a/charts/llm-engine/templates/proportional_a10_autoscaler_deployment.yaml +++ b/charts/model-engine/templates/proportional_a10_autoscaler_deployment.yaml @@ -3,19 +3,19 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-proportional-a10-autoscaler-deployment + name: {{ .Chart.Name }}-proportional-a10-autoscaler-deployment labels: team: infra product: common-warm-nodes spec: selector: matchLabels: - app: llm-engine-proportional-a10-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-a10-autoscaler-deployment version: v1 template: metadata: labels: - app: llm-engine-proportional-a10-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-a10-autoscaler-deployment product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -38,12 +38,12 @@ spec: - /cluster-proportional-autoscaler - --namespace={{ .Release.Namespace }} - --configmap=cluster-proportional-autoscaler - - --target=deployment/llm-engine-balloon-a10 + - --target=deployment/{{ .Chart.Name }}-balloon-a10 - --default-params={"linear":{"nodesPerReplica":10,"preventSinglePointFailure":false,"includeUnschedulableNodes":false}} - --nodelabels=k8s.amazonaws.com/accelerator=nvidia-ampere-a10 - --logtostderr=true - --v=2 priorityClassName: system-cluster-critical - serviceAccountName: {{ include "llmEngine.fullname" . }} + serviceAccountName: {{ include "modelEngine.fullname" . }} {{- end }} {{- end }} diff --git a/charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml b/charts/model-engine/templates/proportional_t4_autoscaler_deployment.yaml similarity index 79% rename from charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml rename to charts/model-engine/templates/proportional_t4_autoscaler_deployment.yaml index 29e5a8e9..7175d985 100644 --- a/charts/llm-engine/templates/proportional_t4_autoscaler_deployment.yaml +++ b/charts/model-engine/templates/proportional_t4_autoscaler_deployment.yaml @@ -3,19 +3,19 @@ apiVersion: apps/v1 kind: Deployment metadata: - name: llm-engine-proportional-t4-autoscaler-deployment + name: {{ .Chart.Name }}-proportional-t4-autoscaler-deployment labels: team: infra product: common-warm-nodes spec: selector: matchLabels: - app: llm-engine-proportional-t4-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-t4-autoscaler-deployment version: v1 template: metadata: labels: - app: llm-engine-proportional-t4-autoscaler-deployment + app: {{ .Chart.Name }}-proportional-t4-autoscaler-deployment product: common-warm-nodes team: infra env: {{ .Values.context }} @@ -38,12 +38,12 @@ spec: - /cluster-proportional-autoscaler - --namespace={{ .Release.Namespace }} - --configmap=cluster-proportional-autoscaler - - --target=deployment/llm-engine-balloon-t4 + - --target=deployment/{{ .Chart.Name }}-balloon-t4 - --default-params={"linear":{"nodesPerReplica":10,"preventSinglePointFailure":false,"includeUnschedulableNodes":false}} - --nodelabels=k8s.amazonaws.com/accelerator=nvidia-tesla-t4 - --logtostderr=true - --v=2 priorityClassName: system-cluster-critical - serviceAccountName: {{ include "llmEngine.fullname" . }} + serviceAccountName: {{ include "modelEngine.fullname" . }} {{- end }} {{- end }} diff --git a/charts/model-engine/templates/recommended_hardware_config_map.yaml b/charts/model-engine/templates/recommended_hardware_config_map.yaml new file mode 100644 index 00000000..b185a5ab --- /dev/null +++ b/charts/model-engine/templates/recommended_hardware_config_map.yaml @@ -0,0 +1,30 @@ +{{ if .Values.recommendedHardware }} +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "modelEngine.fullname" . }}-recommended-hardware-config + labels: + product: common + team: infra +data: + byGpuMemoryGb: |- +{{- range $.Values.recommendedHardware.byGpuMemoryGb }} + - gpu_memory_le: {{ .gpu_memory_le }} + cpus: {{ .cpus }} + gpus: {{ .gpus }} + memory: {{ .memory }} + storage: {{ .storage }} + gpu_type: {{ .gpu_type }} + nodes_per_worker: {{ .nodes_per_worker }} +{{- end }} + byModelName: |- +{{- range $.Values.recommendedHardware.byModelName }} + - name: {{ .name }} + cpus: {{ .cpus }} + gpus: {{ .gpus }} + memory: {{ .memory }} + storage: {{ .storage }} + gpu_type: {{ .gpu_type }} + nodes_per_worker: {{ .nodes_per_worker }} +{{- end }} +{{- end }} diff --git a/charts/model-engine/templates/restart_keda_operator.yaml b/charts/model-engine/templates/restart_keda_operator.yaml new file mode 100644 index 00000000..8937ea82 --- /dev/null +++ b/charts/model-engine/templates/restart_keda_operator.yaml @@ -0,0 +1,57 @@ +# needed for the Azure bicep deployment due to using the default keda installation and a workload identity for auth +# see note in https://learn.microsoft.com/en-us/azure/aks/keda-deploy-add-on-arm +# keda-operator pods need AZURE_TENANT_ID, AZURE_FEDERATED_TOKEN_FILE, and AZURE_AUTHORITY_HOST env vars injected +{{- if .Values.restartKedaOperator }} +apiVersion: batch/v1 +kind: Job +metadata: + name: {{ include "modelEngine.fullname" . }}-restart-keda-operator + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": post-install + "helm.sh/hook-weight": "1" + "helm.sh/hook-delete-policy": hook-succeeded +spec: + backoffLimit: 0 + activeDeadlineSeconds: 600 + template: + metadata: + labels: + sidecar.istio.io/inject: "false" + {{- include "modelEngine.labels" . | nindent 8 }} + spec: + restartPolicy: Never + {{- with .Values.imagePullSecrets }} + imagePullSecrets: + {{- toYaml . | nindent 8 }} + {{- end }} + containers: + - name: {{ include "modelEngine.fullname" . }} + image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag }}" + imagePullPolicy: {{ .Values.image.pullPolicy }} + command: + - kubectl + - rollout + - restart + - deployment + - keda-operator + - -n + - kube-system + {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} + {{- with .Values.nodeSelector }} + nodeSelector: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.affinity }} + affinity: + {{- toYaml . | nindent 8 }} + {{- end }} + {{- with .Values.tolerations }} + tolerations: + {{- toYaml . | nindent 8 }} + {{- end }} +{{- end }} diff --git a/charts/llm-engine/templates/service_account.yaml b/charts/model-engine/templates/service_account.yaml similarity index 50% rename from charts/llm-engine/templates/service_account.yaml rename to charts/model-engine/templates/service_account.yaml index 73be82d7..dc41c998 100644 --- a/charts/llm-engine/templates/service_account.yaml +++ b/charts/model-engine/templates/service_account.yaml @@ -1,7 +1,7 @@ -{{- $serviceAccountName := include "llmEngine.fullname" . }} -{{- $serviceAccountNamespaces := (include "llmEngine.serviceAccountNamespaces" . | fromYaml) }} +{{- $serviceAccountName := include "modelEngine.fullname" . }} +{{- $serviceAccountNamespaces := (include "modelEngine.serviceAccountNamespaces" . | fromYaml) }} {{- $annotations := .Values.serviceAccount.annotations }} -{{- $labels := include "llmEngine.labels" . }} +{{- $labels := include "modelEngine.labels" . }} {{- range $namespace := (index $serviceAccountNamespaces "namespaces") }} apiVersion: v1 kind: ServiceAccount @@ -13,6 +13,13 @@ metadata: {{- with $annotations }} annotations: {{- toYaml . | nindent 4 }} + {{- if $.Values.azure }} + azure.workload.identity/client-id: {{ $.Values.azure.client_id }} + {{- end }} {{- end }} +{{- if $.Values.azure }} +imagePullSecrets: + - name: egp-ecr-regcred +{{- end }} --- {{- end }} diff --git a/charts/model-engine/templates/service_account_image_builder.yaml b/charts/model-engine/templates/service_account_image_builder.yaml new file mode 100644 index 00000000..e68cd7b2 --- /dev/null +++ b/charts/model-engine/templates/service_account_image_builder.yaml @@ -0,0 +1,19 @@ +{{- if and (.Values.imageBuilderServiceAccount) (.Values.imageBuilderServiceAccount.create) }} +{{- $serviceAccountNamespaces := (include "modelEngine.serviceAccountNamespaces" . | fromYaml) }} +{{- $annotations := .Values.imageBuilderServiceAccount.annotations }} +{{- $labels := include "modelEngine.labels" . }} +{{- range $namespace := (index $serviceAccountNamespaces "namespaces") }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: kaniko + namespace: {{- printf " %s" $namespace }} + labels: + {{- $labels | nindent 4 }} + {{- with $annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- end }} +--- +{{- end }} +{{- end }} diff --git a/charts/model-engine/templates/service_account_inference.yaml b/charts/model-engine/templates/service_account_inference.yaml new file mode 100644 index 00000000..9a4a698c --- /dev/null +++ b/charts/model-engine/templates/service_account_inference.yaml @@ -0,0 +1,25 @@ +{{- if and (.Values.serviceTemplate) (.Values.serviceTemplate.createServiceAccount) (.Values.serviceTemplate.serviceAccountAnnotations) (.Values.serviceTemplate.serviceAccountName) (.Values.config.values.launch.endpoint_namespace)}} +{{- $annotations := .Values.serviceTemplate.serviceAccountAnnotations }} +{{- $inferenceServiceAccountName := .Values.serviceTemplate.serviceAccountName }} +{{- $inferenceServiceAccountNamespace := .Values.config.values.launch.endpoint_namespace }} +{{- $labels := include "modelEngine.labels" . }} +apiVersion: v1 +kind: ServiceAccount +metadata: + name: {{- printf " %s" $inferenceServiceAccountName }} + namespace: {{- printf " %s" $inferenceServiceAccountNamespace }} + labels: + {{- $labels | nindent 4 }} + {{- with $annotations }} + annotations: + {{- toYaml . | nindent 4 }} + {{- if $.Values.azure }} + azure.workload.identity/client-id: {{ $.Values.azure.client_id }} + {{- end }} + {{- end }} +{{- if $.Values.azure }} +imagePullSecrets: + - name: egp-ecr-regcred +{{- end }} +--- +{{- end }} diff --git a/charts/model-engine/templates/service_config_map.yaml b/charts/model-engine/templates/service_config_map.yaml new file mode 100644 index 00000000..403bb552 --- /dev/null +++ b/charts/model-engine/templates/service_config_map.yaml @@ -0,0 +1,56 @@ +{{- if .Values.config.values }} +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "modelEngine.fullname" . }}-service-config + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" +data: + launch_service_config: |- + dd_trace_enabled: {{ .Values.dd_trace_enabled | default false | quote }} + gateway_namespace: {{ .Release.Namespace | quote }} + {{- with .Values.config.values.launch }} + {{- range $key, $value := . }} + {{ $key }}: {{ $value | quote }} + {{- end }} + {{- end }} + infra_service_config: |- + env: {{ .Values.context | quote }} + {{- with .Values.config.values.infra }} + {{- range $key, $value := . }} + {{ $key }}: {{ $value | quote }} + {{- end }} + {{- end }} + +--- + +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ include "modelEngine.fullname" . }}-service-config + namespace: {{ .Values.config.values.launch.endpoint_namespace }} + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" +data: + launch_service_config: |- + dd_trace_enabled: {{ .Values.dd_trace_enabled | default false | quote }} + gateway_namespace: {{ .Release.Namespace | quote }} + {{- with .Values.config.values.launch }} + {{- range $key, $value := . }} + {{ $key }}: {{ $value | quote }} + {{- end }} + {{- end }} + infra_service_config: |- + env: {{ .Values.context | quote }} + {{- with .Values.config.values.infra }} + {{- range $key, $value := . }} + {{ $key }}: {{ $value | quote }} + {{- end }} + {{- end }} +{{- end }} diff --git a/charts/model-engine/templates/service_template_config_map.yaml b/charts/model-engine/templates/service_template_config_map.yaml new file mode 100644 index 00000000..6836784c --- /dev/null +++ b/charts/model-engine/templates/service_template_config_map.yaml @@ -0,0 +1,1314 @@ +{{- $launch_name := include "modelEngine.fullname" . }} +{{- $config_values := .Values.config.values }} +{{- $forwarder_repository := .Values.image.forwarderRepository -}} +{{- $triton_repository := .Values.triton.image.repository -}} +{{- $triton_tag := .Values.triton.image.tag -}} +{{- $env := .Values.context -}} +{{- $service_template_labels := include "modelEngine.serviceTemplateLabels" . }} +{{- $job_template_labels := include "modelEngine.jobTemplateLabels" . }} +{{- $service_env := include "modelEngine.serviceEnvGitTagFromPythonReplace" . }} +{{- $async_service_template_env := include "modelEngine.asyncServiceTemplateEnv" . }} +{{- $sync_service_template_env := include "modelEngine.syncServiceTemplateEnv" . }} +{{- $async_forwarder_template_env := include "modelEngine.asyncForwarderTemplateEnv" . }} +{{- $sync_forwarder_template_env := include "modelEngine.syncForwarderTemplateEnv" . }} +{{- $forwarder_volume_mounts := include "modelEngine.forwarderVolumeMounts" . }} +{{- $gateway_repository := .Values.image.gatewayRepository -}} +{{- $tag := .Values.tag -}} +{{- $aws_config_map_name := (.Values.aws).configMap.name }} +{{- $security_context := .Values.serviceTemplate.securityContext }} +{{- $mount_infra_config := .Values.serviceTemplate.mountInfraConfig }} +{{- $service_template_service_account_name := .Values.serviceTemplate.serviceAccountName }} +{{- $service_template_aws_config_map_name := .Values.serviceTemplate.awsConfigMapName }} +{{- $celery_broker_type := .Values.celeryBrokerType }} +{{- $node_selector := .Values.nodeSelector }} +{{- $require_aws_config := not (empty .Values.aws) }} +{{- $enable_datadog := .Values.datadog.enabled }} +{{- $azure_cloud_provider := not (empty .Values.azure) }} + +{{- if .Values.message }} +{{- .Values.message }} +{{- end }} +apiVersion: v1 +kind: ConfigMap +metadata: + name: {{ $launch_name }}-service-template-config + labels: + {{- include "modelEngine.labels" . | nindent 4 }} + annotations: + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" +data: + {{- range $device := tuple "cpu" "gpu" }} + {{- range $mode := tuple "async" "sync" "streaming"}} + {{- range $flavor := tuple "triton-enhanced-runnable-image" "runnable-image" }} + {{- if or (ne $mode "streaming") (eq $flavor "runnable-image") }} + deployment-{{ $flavor }}-{{ $mode }}-{{ $device }}.yaml: |- + apiVersion: apps/v1 + kind: Deployment + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + {{- if eq $mode "async" }} + annotations: + {{- include "modelEngine.serviceTemplateAsyncAnnotations" . | nindent 8 }} + {{- end }} + spec: + strategy: + type: RollingUpdate + rollingUpdate: + maxSurge: 1 + maxUnavailable: 0 + replicas: ${MIN_WORKERS} + selector: + matchLabels: + app: ${RESOURCE_NAME} + version: v1 + template: + metadata: + labels: + app: ${RESOURCE_NAME} + {{- $service_template_labels | nindent 12 }} + {{- if eq $mode "async" }} + sidecar.istio.io/inject: "false" # TODO: switch to scuttle + {{- end }} + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + {{- include "modelEngine.serviceTemplateAffinity" . | nindent 12 }} + {{- if eq $mode "async" }} + terminationGracePeriodSeconds: 1800 + {{- else }} + terminationGracePeriodSeconds: 600 + {{- end }} + {{- if $service_template_service_account_name }} + serviceAccount: {{ $service_template_service_account_name }} + {{- else }} + serviceAccount: {{ $launch_name }} + {{- end }} + {{- with $node_selector }} + nodeSelector: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- if eq $device "gpu" }} + {{- if empty $node_selector }} + nodeSelector: + {{- end }} + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + {{- end }} + priorityClassName: ${PRIORITY} + containers: + {{- if contains "runnable-image" $flavor }} + {{- if eq $mode "sync" }} + - name: http-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + {{- $sync_forwarder_template_env | nindent 14 }} + readinessProbe: + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 14 }} + ports: + - containerPort: ${FORWARDER_PORT} + name: http + {{- else if eq $mode "streaming" }} + - name: http-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.stream.predict_route=${STREAMING_PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" + {{- $sync_forwarder_template_env | nindent 14 }} + readinessProbe: + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 14 }} + ports: + - containerPort: ${FORWARDER_PORT} + name: http + {{- else if eq $mode "async" }} + - name: celery-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - --queue + - "${QUEUE}" + - --task-visibility + - "VISIBILITY_24H" + - --set + - "forwarder.async.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" + {{- if eq $celery_broker_type "sqs" }} + - --sqs-url + - "${SQS_QUEUE_URL}" + {{- end }} + - --num-workers + - "${PER_WORKER}" + - --broker-type + - {{ $celery_broker_type }} + {{- if eq $celery_broker_type "servicebus" }} + - --backend-protocol + - abs + {{- end }} + {{- $async_forwarder_template_env | nindent 14 }} + resources: + requests: + cpu: 0.1 + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 14 }} + {{- end }} + {{- if eq $flavor "triton-enhanced-runnable-image" }} + - name: tritonserver + image: {{ $triton_repository }}:${TRITON_COMMIT_TAG}-triton + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + - bash + - -c + - "$TRITON_COMMAND" + env: + - name: AWS_PROFILE + value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" + ports: + - containerPort: 8000 + name: http + - containerPort: 8001 + name: grpc + - containerPort: 8002 + name: metrics + readinessProbe: + httpGet: + # Need to have Triton support --http-address IPv6 :( + # https://github:com/triton-inference-server/server/issues/5305: + # path: /v2/health/ready + # port: 8000 + path: /readyz + port: 3000 + initialDelaySeconds: $TRITON_READINESS_INITIAL_DELAY + periodSeconds: 10 + resources: + requests: + cpu: ${TRITON_CPUS} + ${TRITON_MEMORY_DICT} + ${TRITON_STORAGE_DICT} + limits: + cpu: ${TRITON_CPUS} + ${TRITON_MEMORY_DICT} + ${TRITON_STORAGE_DICT} + volumeMounts: + {{- if $require_aws_config }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - mountPath: /dev/shm + name: dshm + {{- end }} + - name: main + {{- with $security_context }} + securityContext: + {{- toYaml . | nindent 16 }} + {{- end }} + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} + readinessProbe: + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + {{- if $require_aws_config }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - mountPath: /dev/shm + name: dshm + {{- if $mount_infra_config }} + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + {{- end }} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + {{- end }} + # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 + securityContext: + fsGroup: 65534 + volumes: + {{- if $require_aws_config }} + - name: config-volume + configMap: + {{- if $service_template_aws_config_map_name }} + name: {{ $service_template_aws_config_map_name }} + {{- else }} + name: {{ $aws_config_map_name }} + {{- end }} + {{- end }} + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + {{- if $config_values }} + - name: infra-service-config-volume + configMap: + name: {{ $launch_name }}-service-config + items: + - key: infra_service_config + path: config.yaml + {{- end }} + {{- end }} + {{- end }} + {{- end }} + {{- end }} + user-config.yaml: |- + apiVersion: v1 + kind: ConfigMap + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + data: + raw_data: ${CONFIG_DATA_SERIALIZED} + endpoint-config.yaml: |- + apiVersion: v1 + kind: ConfigMap + metadata: + name: ${RESOURCE_NAME}-endpoint-config + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + data: + raw_data: ${ENDPOINT_CONFIG_SERIALIZED} + horizontal-pod-autoscaler.yaml: |- + apiVersion: ${API_VERSION} + kind: HorizontalPodAutoscaler + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + minReplicas: ${MIN_WORKERS} + maxReplicas: ${MAX_WORKERS} + scaleTargetRef: + apiVersion: apps/v1 + kind: Deployment + name: ${RESOURCE_NAME} + metrics: + - type: Pods + pods: + metric: + name: request-concurrency-average + target: + type: Value + averageValue: ${CONCURRENCY} + keda-scaled-object.yaml: |- + apiVersion: keda.sh/v1alpha1 + kind: ScaledObject + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + scaleTargetRef: + name: ${RESOURCE_NAME} + pollingInterval: 5 + cooldownPeriod: 300 + minReplicaCount: ${MIN_WORKERS} + maxReplicaCount: ${MAX_WORKERS} + fallback: + failureThreshold: 3 + replicas: ${MIN_WORKERS} + triggers: + {{- if $azure_cloud_provider }} + - type: azure-servicebus + metadata: + queueName: "launch-endpoint-autoscaling.${ENDPOINT_ID}" + namespace: ${SERVICEBUS_NAMESPACE} + messageCount: "100" + activationMessageCount: "0" + authenticationRef: + name: "${AUTHENTICATION_REF}" + {{- else }} + - type: redis + metadata: + address: ${REDIS_HOST_PORT} # Format must be host:port + {{- if not .Values.redis.enableAuth }} + passwordFromEnv: "" + {{- end }} + listName: "launch-endpoint-autoscaling:${ENDPOINT_ID}" + listLength: "100" # something absurdly high so we don't scale past 1 pod + activationListLength: "0" + enableTLS: "{{ .Values.redis.enableTLS }}" + unsafeSsl: "{{ .Values.redis.unsafeSsl }}" + databaseIndex: "${REDIS_DB_INDEX}" + {{- if .Values.redis.enableAuth }} + authenticationRef: + name: "keda-trigger-auth-redis-secret" + {{- end }} + {{- end }} + - type: prometheus + metadata: + threshold: "${CONCURRENCY}" + metricName: request_concurrency_average + query: sum(rate(istio_request_duration_milliseconds_sum{destination_workload="${RESOURCE_NAME}"}[2m])) / 1000 + serverAddress: ${PROMETHEUS_SERVER_ADDRESS} + {{- range $device := tuple "gpu" }} + {{- range $mode := tuple "streaming"}} + leader-worker-set-{{ $mode }}-{{ $device }}.yaml: |- + apiVersion: leaderworkerset.x-k8s.io/v1 + kind: LeaderWorkerSet + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + replicas: ${MIN_WORKERS} + leaderWorkerTemplate: + size: ${LWS_SIZE} + restartPolicy: RecreateGroupOnPodRestart # TODO un-hardcode? if necessary + leaderTemplate: + metadata: + labels: + app: ${RESOURCE_NAME} + role: leader + {{- $service_template_labels | nindent 14 }} + sidecar.istio.io/inject: "false" # Never inject istio, it screws up networking + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + {{- include "modelEngine.serviceTemplateAffinity" . | nindent 14 }} + {{- if eq $mode "async" }} # TODO + terminationGracePeriodSeconds: 1800 + {{- else }} + terminationGracePeriodSeconds: 600 + {{- end }} + {{- if $service_template_service_account_name }} + serviceAccount: {{ $service_template_service_account_name }} + {{- else }} + serviceAccount: {{ $launch_name }} + {{- end }} + {{- with $node_selector }} + nodeSelector: + {{- toYaml . | nindent 14 }} + {{- end }} + {{- if eq $device "gpu" }} + {{- if empty $node_selector }} + nodeSelector: + {{- end }} + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + {{- end }} + priorityClassName: ${PRIORITY} + containers: + {{- if eq $mode "sync" }} + - name: http-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + {{- $sync_forwarder_template_env | nindent 16 }} + readinessProbe: + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 16 }} + ports: + - containerPort: ${FORWARDER_PORT} + name: http + {{- else if eq $mode "streaming" }} + - name: http-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.stream.predict_route=${STREAMING_PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + {{- $sync_forwarder_template_env | nindent 16 }} + readinessProbe: + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 16 }} + ports: + - containerPort: ${FORWARDER_PORT} + name: http + {{- else if eq $mode "async" }} + - name: celery-forwarder + image: {{ $forwarder_repository }}:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - --queue + - "${QUEUE}" + - --task-visibility + - "VISIBILITY_24H" + - --set + - "forwarder.async.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" + {{- if eq $celery_broker_type "sqs" }} + - --sqs-url + - "${SQS_QUEUE_URL}" + {{- end }} + - --num-workers + - "${PER_WORKER}" + - --broker-type + - {{ $celery_broker_type }} + {{- if eq $celery_broker_type "servicebus" }} + - --backend-protocol + - abs + {{- end }} + {{- $async_forwarder_template_env | nindent 16 }} + resources: + requests: + cpu: 0.1 + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + {{ $forwarder_volume_mounts | nindent 16 }} + {{- end }} + - name: lws-leader + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} + readinessProbe: + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + {{- if $require_aws_config }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - mountPath: /dev/shm + name: dshm + {{- if $mount_infra_config }} + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + {{- end }} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + volumes: + {{- if $require_aws_config }} + - name: config-volume + configMap: + {{- if $service_template_aws_config_map_name }} + name: {{ $service_template_aws_config_map_name }} + {{- else }} + name: {{ $aws_config_map_name }} + {{- end }} + {{- end }} + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + {{- if $config_values }} + - name: infra-service-config-volume + configMap: + name: {{ $launch_name }}-service-config + items: + - key: infra_service_config + path: config.yaml + {{- end }} + workerTemplate: + metadata: + labels: + app: ${RESOURCE_NAME} + role: worker + {{- $service_template_labels | nindent 14 }} + sidecar.istio.io/inject: "false" # Never inject istio for LWS, it screws up networking + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + {{- include "modelEngine.serviceTemplateAffinity" . | nindent 14 }} + {{- if eq $mode "async" }} # TODO + terminationGracePeriodSeconds: 1800 + {{- else }} + terminationGracePeriodSeconds: 600 + {{- end }} + {{- if $service_template_service_account_name }} + serviceAccount: {{ $service_template_service_account_name }} + {{- else }} + serviceAccount: {{ $launch_name }} + {{- end }} + {{- with $node_selector }} + nodeSelector: + {{- toYaml . | nindent 14 }} + {{- end }} + {{- if eq $device "gpu" }} + {{- if empty $node_selector }} + nodeSelector: + {{- end }} + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + {{- end }} + priorityClassName: ${PRIORITY} + containers: + - name: lws-worker + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${WORKER_COMMAND} + env: ${WORKER_ENV} + resources: + requests: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + {{- if $require_aws_config }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - mountPath: /dev/shm + name: dshm + {{- if $mount_infra_config }} + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + {{- end }} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + volumes: + {{- if $require_aws_config }} + - name: config-volume + configMap: + {{- if $service_template_aws_config_map_name }} + name: {{ $service_template_aws_config_map_name }} + {{- else }} + name: {{ $aws_config_map_name }} + {{- end }} + {{- end }} + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + {{- if $config_values }} + - name: infra-service-config-volume + configMap: + name: {{ $launch_name }}-service-config + items: + - key: infra_service_config + path: config.yaml + {{- end }} + {{- end }} # mode + {{- end }} # device + service.yaml: |- + apiVersion: v1 + kind: Service + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + type: ${SERVICE_TYPE} + selector: + app: ${RESOURCE_NAME} + ports: + - port: 80 + targetPort: ${SERVICE_TARGET_PORT} + protocol: TCP + name: http + ${NODE_PORT_DICT} + lws-service.yaml: |- + apiVersion: v1 + kind: Service + metadata: + name: ${SERVICE_NAME_OVERRIDE} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + type: ${SERVICE_TYPE} + selector: + app: ${RESOURCE_NAME} + role: leader + ports: + - port: 80 + targetPort: ${SERVICE_TARGET_PORT} + protocol: TCP + name: http + ${NODE_PORT_DICT} + {{- if .Values.virtualservice.enabled }} + virtual-service.yaml: |- + apiVersion: networking.istio.io/v1alpha3 + kind: VirtualService + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + hosts: + - ${RESOURCE_NAME}.${DNS_HOST_DOMAIN} + gateways: + - default/internal-gateway + http: + - route: + - destination: + host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" + port: + number: 80 + {{- end }} + {{- if .Values.destinationrule.enabled }} + destination-rule.yaml: |- + apiVersion: networking.istio.io/v1beta1 + kind: DestinationRule + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" + trafficPolicy: + loadBalancer: + simple: LEAST_REQUEST + {{- end }} + lws-service-entry.yaml: |- + apiVersion: networking.istio.io/v1beta1 + kind: ServiceEntry + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + hosts: + - "${SERVICE_NAME_OVERRIDE}.${NAMESPACE}.svc.cluster.local" + location: MESH_EXTERNAL + ports: + - number: 80 + name: http + protocol: HTTP + resolution: NONE + vertical-pod-autoscaler.yaml: |- + apiVersion: "autoscaling.k8s.io/v1" + kind: VerticalPodAutoscaler + metadata: + name: ${RESOURCE_NAME} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + targetRef: + apiVersion: "apps/v1" + kind: Deployment + name: ${RESOURCE_NAME} + updatePolicy: + updateMode: "Auto" + resourcePolicy: + containerPolicies: + - containerName: istio-proxy + mode: "Off" + - containerName: main + minAllowed: + cpu: 100m + memory: 128Mi + maxAllowed: + cpu: ${CPUS} + memory: ${MEMORY} + controlledResources: ["cpu", "memory"] + pod-disruption-budget.yaml: |- + apiVersion: policy/v1 + kind: PodDisruptionBudget + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + {{- $service_template_labels | nindent 8 }} + spec: + maxUnavailable: 50% + selector: + matchLabels: + app: ${RESOURCE_NAME} + batch-job-orchestration-job.yaml: |- + apiVersion: batch/v1 + kind: Job + metadata: + name: ${RESOURCE_NAME} + labels: + {{- $job_template_labels | nindent 8 }} + spec: + backoffLimit: 0 + activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} + ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} + template: + metadata: + labels: + {{- $job_template_labels | nindent 12 }} + sidecar.istio.io/inject: "false" + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "launch_job_id:${JOB_ID}"]}]' + cluster-autoscaler.kubernetes.io/safe-to-evict: "false" + spec: + restartPolicy: Never + {{- with $node_selector }} + nodeSelector: + {{- toYaml . | nindent 12 }} + {{- end }} + serviceAccountName: {{ $launch_name }} + {{- if $require_aws_config }} + volumes: + - name: config-volume + configMap: + name: {{ $aws_config_map_name }} + {{- end }} + containers: + - name: main + image: {{ $gateway_repository }}:${GIT_TAG} + env: + - name: DD_SERVICE + value: ${RESOURCE_NAME} + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" + {{- $env_vars := $service_env | fromYaml }} + {{- range $env_var := index $env_vars "env" }} + {{- $env_var_name := index $env_var "name" }} + {{- if ne $env_var_name "DD_SERVICE" }} + {{- tuple $env_var | toYaml | nindent 16 }} + {{- end }} + {{- end }} + imagePullPolicy: IfNotPresent + command: + - dumb-init + - -- + {{- if $enable_datadog }} + - ddtrace-run + {{- end }} + args: + - python + - -m + - model_engine_server.entrypoints.start_batch_job_orchestration + - --job-id + - ${JOB_ID} + - --owner + - ${OWNER} + - --input-path + - ${INPUT_LOCATION} + - --serialization-format + - ${SERIALIZATION_FORMAT} + - --timeout-seconds + - "${BATCH_JOB_TIMEOUT}" + resources: + # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. + requests: + cpu: 1 + memory: 8Gi + ephemeral-storage: 10Gi + limits: + cpu: 4 + memory: 32Gi + ephemeral-storage: 30Gi + {{- if $require_aws_config }} + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + {{- range $device := tuple "cpu" "gpu" }} + docker-image-batch-job-{{- $device }}.yaml: |- + apiVersion: batch/v1 + kind: Job + metadata: + name: ${RESOURCE_NAME} + labels: + {{- $job_template_labels | nindent 8 }} + spec: + backoffLimit: 0 + activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} + ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} + completions: ${BATCH_JOB_NUM_WORKERS} + parallelism: ${BATCH_JOB_NUM_WORKERS} + completionMode: "Indexed" + template: + metadata: + labels: + {{- $job_template_labels | nindent 12 }} + sidecar.istio.io/inject: "false" + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:{{ $env }}", "launch_job_id:${JOB_ID}"]}]' + cluster-autoscaler.kubernetes.io/safe-to-evict: "false" + spec: + restartPolicy: Never + {{- with $node_selector }} + nodeSelector: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- if eq $device "gpu" }} + {{- if empty $node_selector }} + nodeSelector: + {{- end }} + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + {{- end }} + {{- if $service_template_service_account_name }} + serviceAccountName: {{ $service_template_service_account_name }} + {{- else }} + serviceAccountName: {{ $launch_name }} + {{- end }} + volumes: + {{- if $require_aws_config }} + - name: config-volume + configMap: + name: {{ $aws_config_map_name }} + {{- end }} + - name: workdir + emptyDir: {} + - name: dshm + emptyDir: + medium: Memory + containers: + - name: main + image: ${IMAGE} + env: + - name: DD_SERVICE + value: ${RESOURCE_NAME} + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" + {{- $env_vars := $service_env | fromYaml }} + {{- range $env_var := index $env_vars "env" }} + {{- $env_var_name := index $env_var "name" }} + {{- if ne $env_var_name "DD_SERVICE" }} + {{- tuple $env_var | toYaml | nindent 16 }} + {{- end }} + {{- end }} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + resources: + # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. + requests: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + {{- if eq $device "gpu" }} + nvidia.com/gpu: ${GPUS} + {{- end }} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + {{- if $require_aws_config }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - name: workdir + mountPath: ${MOUNT_PATH} + - mountPath: /dev/shm + name: dshm + initContainers: + - name: input-downloader + image: {{ $gateway_repository }}:${GIT_TAG} + env: + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" + command: + - python + - -m + - model_engine_server.entrypoints.start_docker_image_batch_job_init_container + - ${INPUT_LOCATION} + - --remote-file + - ${S3_FILE} + - --local-file + - ${LOCAL_FILE_NAME} + - --file-contents-b64encoded + - ${FILE_CONTENTS_B64ENCODED} + resources: + requests: + cpu: 1 + memory: 1Gi + limits: + cpu: 1 + memory: 1Gi + volumeMounts: + {{- if $require_aws_config }} + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + {{- end }} + - name: workdir + mountPath: ${MOUNT_PATH} + {{- end }} + {{- range $device := .Values.imageCache.devices }} + {{- $device_node_selector := index $device "nodeSelector" }} + {{- $device_tolerations := index $device "tolerations" }} + image-cache-{{- index $device "name" }}.yaml: |- + apiVersion: apps/v1 + kind: DaemonSet + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + team: infra + product: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/service: ${RESOURCE_NAME} + spec: + selector: + matchLabels: + app: ${RESOURCE_NAME} + version: v1 + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + app: ${RESOURCE_NAME} + team: infra + product: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/service: ${RESOURCE_NAME} + version: v1 + sidecar.istio.io/inject: "false" + spec: + {{- if $device_node_selector }} + {{- with $device_node_selector }} + nodeSelector: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- end }} + {{- if $device_tolerations }} + {{- with $device_tolerations }} + tolerations: + {{- toYaml . | nindent 12 }} + {{- end }} + {{- end }} + containers: + - image: public.ecr.aws/docker/library/busybox:latest + imagePullPolicy: IfNotPresent + name: busybox + command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] + terminationGracePeriodSeconds: 0 + {{- end }} + cron-trigger.yaml: |- + apiVersion: batch/v1 + kind: CronJob + metadata: + name: ${NAME} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + launch_trigger_id: ${TRIGGER_ID} + tags.datadoghq.com/service: ${TRIGGER_ID} + spec: + schedule: "${CRON_SCHEDULE}" + successfulJobsHistoryLimit: 0 + failedJobsHistoryLimit: 0 + jobTemplate: + spec: + backoffLimit: 0 + activeDeadlineSeconds: ${BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS} + template: + metadata: + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + launch_trigger_id: ${TRIGGER_ID} + tags.datadoghq.com/service: ${TRIGGER_ID} + spec: + containers: + - name: ${NAME} + image: curlimages/curl:7.72.0 + imagePullPolicy: IfNotPresent + command: + - curl + - -X + - 'POST' + - '${HOST}/v1/docker-image-batch-jobs' + - -H + - 'accept: application/json' + - -H + - 'Content-Type: application/json' + - -d + - '{ "docker_image_batch_job_bundle_id": "${DOCKER_IMAGE_BATCH_JOB_BUNDLE_ID}", "job_config": ${JOB_CONFIG}, "labels": ${JOB_METADATA} }' + - -u + - '${OWNER}:' + restartPolicy: Never diff --git a/charts/llm-engine/templates/llm_engine_init_job.yaml b/charts/model-engine/templates/spellbook_init_job.yaml similarity index 60% rename from charts/llm-engine/templates/llm_engine_init_job.yaml rename to charts/model-engine/templates/spellbook_init_job.yaml index 1892d087..ed23f4e6 100644 --- a/charts/llm-engine/templates/llm_engine_init_job.yaml +++ b/charts/model-engine/templates/spellbook_init_job.yaml @@ -1,10 +1,10 @@ -{{- if .Values.secrets.kubernetesDatabaseSecretName }} +{{- if and (.Values.secrets.kubernetesDatabaseSecretName) (.Values.spellbook.enabled) }} apiVersion: batch/v1 kind: Job metadata: - name: {{ include "llmEngine.fullname" . }}-init-job + name: {{ include "modelEngine.fullname" . }}-spellbook-setup labels: - {{- include "llmEngine.labels" . | nindent 4 }} + {{- include "modelEngine.labels" . | nindent 4 }} annotations: "helm.sh/hook": post-install "helm.sh/hook-weight": "0" @@ -16,7 +16,7 @@ spec: metadata: labels: sidecar.istio.io/inject: "false" - {{- include "llmEngine.labels" . | nindent 8 }} + {{- include "modelEngine.labels" . | nindent 8 }} spec: restartPolicy: Never {{- with .Values.imagePullSecrets }} @@ -24,7 +24,7 @@ spec: {{- toYaml . | nindent 8 }} {{- end }} containers: - - name: {{ include "llmEngine.fullname" . }} + - name: {{ include "modelEngine.fullname" . }} image: "{{ .Values.image.gatewayRepository }}:{{ .Values.tag}}" imagePullPolicy: {{ .Values.image.pullPolicy }} command: @@ -33,13 +33,13 @@ spec: args: - python - -m - - server.llm_engine_server.entrypoints.init_llm_engine_models + - model_engine_server.entrypoints.init_spellbook_models - --gateway-url - - 'http://{{- include "llmEngine.fullname" . }}.{{ .Release.Namespace }}:{{ .Values.service.port }}' - {{- include "llmEngine.serviceEnv" . | indent 10 }} - {{- include "llmEngine.volumeMounts" . | indent 10 }} - serviceAccountName: {{ include "llmEngine.fullname" . }} - {{- include "llmEngine.volumes" . | indent 6 }} + - '{{- include "modelEngine.gatewayurl" . }}' + {{- include "modelEngine.serviceEnvGitTagFromHelmVar" . | indent 10 }} + {{- include "modelEngine.volumeMounts" . | indent 10 }} + serviceAccountName: {{ include "modelEngine.fullname" . }} + {{- include "modelEngine.volumes" . | indent 6 }} {{- with .Values.nodeSelector }} nodeSelector: {{- toYaml . | nindent 8 }} diff --git a/charts/model-engine/templates/trigger_authentication.yaml b/charts/model-engine/templates/trigger_authentication.yaml new file mode 100644 index 00000000..088dee94 --- /dev/null +++ b/charts/model-engine/templates/trigger_authentication.yaml @@ -0,0 +1,24 @@ +{{- if .Values.azure }} +apiVersion: keda.sh/v1alpha1 +kind: TriggerAuthentication +metadata: + name: azure-workload-identity + namespace: {{ .Values.config.values.launch.endpoint_namespace }} +spec: + podIdentity: + provider: azure-workload + identityId: {{ .Values.azure.client_id }} +{{- else if .Values.redis.enableAuth }} +apiVersion: keda.sh/v1alpha1 +kind: TriggerAuthentication +metadata: + name: keda-trigger-auth-redis-secret + namespace: {{ .Values.config.values.launch.endpoint_namespace }} +spec: + awsSecretManager: + podIdentity: + provider: aws + secrets: + - parameter: password + name: {{ .Values.redis.kedaSecretName }} +{{- end }} diff --git a/charts/model-engine/values.yaml b/charts/model-engine/values.yaml new file mode 100644 index 00000000..b8e60da0 --- /dev/null +++ b/charts/model-engine/values.yaml @@ -0,0 +1,18 @@ +dd_trace_enabled: true +spellbook: + enabled: false +redis: + auth: + enableTLS: false + enableAuth: false + kedaSecretName: "" + unsafeSsl: false +db: + runDbInitScript: false + runDbMigrationScript: false +balloonConfig: + reserveHighPriority: true +balloonNodeSelector: + node-lifecycle: normal +nodeSelector: + node-lifecycle: normal diff --git a/charts/model-engine/values_circleci.yaml b/charts/model-engine/values_circleci.yaml new file mode 100644 index 00000000..9096c76f --- /dev/null +++ b/charts/model-engine/values_circleci.yaml @@ -0,0 +1,316 @@ +# This is a YAML-formatted file. + +replicaCount: + gateway: 1 + cacher: 1 + builder: 1 + +balloons: + - acceleratorName: nvidia-ampere-a10 + replicaCount: 0 + - acceleratorName: nvidia-ampere-a100 + replicaCount: 0 + - acceleratorName: cpu + replicaCount: 0 + - acceleratorName: nvidia-tesla-t4 + replicaCount: 0 + - acceleratorName: nvidia-hopper-h100 + replicaCount: 0 + +# tag needs to be set dynamically every time. Usually it is set to the SHA1 hash of the git +# commit from which the image was built. +# tag: +context: circleci +image: + gatewayRepository: model-engine + builderRepository: model-engine + cacherRepository: model-engine + forwarderRepository: model-engine + pullPolicy: IfNotPresent + +# serviceIdentifier: + +secrets: + kubernetesDatabaseSecretName: model-engine-postgres-credentials + + +service: + type: ClusterIP + port: 80 + +virtualservice: + enabled: true + annotations: { } + hostDomains: + - example.com + gateways: + - default/internal-gateway + +hostDomain: + prefix: http:// + +destinationrule: + enabled: true + annotations: { } + +autoscaling: + horizontal: + enabled: false + minReplicas: 1 + maxReplicas: 10 + targetConcurrency: 30 + vertical: + enabled: false + minAllowed: + cpu: 100m + memory: 128Mi + maxAllowed: + cpu: 10 + memory: 8Gi + updateMode: Auto + prewarming: + enabled: false + +celery_autoscaler: + enabled: false + +podDisruptionBudget: + enabled: true + minAvailable: 1 + +resources: + requests: + cpu: 2 + +nodeSelector: null + +balloonNodeSelector: null + +tolerations: [ ] + +affinity: { } + +config: + values: + infra: + cloud_provider: aws + k8s_cluster_name: minikube + dns_host_domain: localhost + default_region: us-west-2 + ml_account_id: "$CIRCLECI_AWS_ACCOUNT_ID" + docker_repo_prefix: "$CIRCLECI_AWS_ACCOUNT_ID.dkr.ecr.us-west-2.amazonaws.com" + redis_host: redis-message-broker-master.default + s3_bucket: "$CIRCLECI_AWS_S3_BUCKET" + profile_ml_worker: "default" + profile_ml_inference_worker: "default" + launch: + # Endpoint config + # K8s namespace the endpoints will be created in + endpoint_namespace: model-engine + model_primitive_host: none + + # Asynchronous endpoints + sqs_profile: default + sqs_queue_policy_template: > + { + "Version": "2012-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Sid": "__owner_statement", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::$CIRCLECI_AWS_ACCOUNT_ID:root" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-west-2:$CIRCLECI_AWS_ACCOUNT_ID:${queue_name}" + }, + { + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::$CIRCLECI_AWS_ACCOUNT_ID:role/default" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-west-2:$CIRCLECI_AWS_ACCOUNT_ID:${queue_name}" + }, + { + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::$CIRCLECI_AWS_ACCOUNT_ID:role/ml_llm_engine" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-west-2:$CIRCLECI_AWS_ACCOUNT_ID:${queue_name}" + } + ] + } + sqs_queue_tag_template: > + { + "Spellbook-Serve-Endpoint-Id": "${endpoint_id}", + "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", + "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" + } + + billing_queue_arn: none + cache_redis_aws_url: redis://redis-message-broker-master.default/15 + cloud_file_llm_fine_tune_repository: "s3://$CIRCLECI_AWS_S3_BUCKET/fine_tune_repository" + dd_trace_enabled: false + istio_enabled: true + sensitive_log_mode: false + tgi_repository: "text-generation-inference" + vllm_repository: "vllm" + lightllm_repository: "lightllm" + tensorrt_llm_repository: "tensorrt-llm" + batch_inference_vllm_repository: "llm-engine/batch-infer-vllm" + user_inference_base_repository: "launch/inference" + user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" + user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" + docker_image_layer_cache_repository: "kaniko-cache" + hf_user_fine_tuned_weights_prefix: "s3://$CIRCLECI_AWS_S3_BUCKET/model-weights" + +# Service Account +serviceAccount: + annotations: + "eks.amazonaws.com/role-arn": arn:aws:iam::$CIRCLECI_AWS_ACCOUNT_ID:role/default + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" + namespaces: + - default + - model-engine + +aws: + configMap: + name: default-config + create: false + mountPath: /opt/.aws/config + profileName: default + s3WriteProfileName: default + +forwarder: + forceUseIPv4: true + +triton: + image: + repository: nvidia/tritonserver + tag: latest + +serviceTemplate: + securityContext: + capabilities: + drop: + - all + mountInfraConfig: true + serviceAccountName: default + awsConfigMapName: default-config + +imageCache: + devices: + - name: cpu + nodeSelector: + cpu-only: "true" + - name: a10 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-ampere-a10 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: a100 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-ampere-a100 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: t4 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-tesla-t4 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: h100 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + +celeryBrokerType: redis + +datadog: + enabled: false + +recommendedHardware: + byGpuMemoryGb: + - gpu_memory_le: 24 + cpus: 10 + gpus: 1 + memory: 24Gi + storage: 80Gi + gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 + - gpu_memory_le: 48 + cpus: 20 + gpus: 2 + memory: 48Gi + storage: 80Gi + gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 + - gpu_memory_le: 96 + cpus: 40 + gpus: 4 + memory: 96Gi + storage: 96Gi + gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 + - gpu_memory_le: 180 + cpus: 20 + gpus: 2 + memory: 160Gi + storage: 160Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 320 + cpus: 40 + gpus: 4 + memory: 320Gi + storage: 320Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 640 + cpus: 80 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 1280 + cpus: 80 + gpus: 8 + memory: 800Gi + storage: 900Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 2 + byModelName: + - name: llama-3-8b-instruct-262k + cpus: 20 + gpus: 2 + memory: 40Gi + storage: 40Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - name: deepseek-coder-v2 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - name: deepseek-coder-v2-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 \ No newline at end of file diff --git a/charts/model-engine/values_sample.yaml b/charts/model-engine/values_sample.yaml new file mode 100644 index 00000000..f7e1fe58 --- /dev/null +++ b/charts/model-engine/values_sample.yaml @@ -0,0 +1,400 @@ +# This is a YAML-formatted file. + +# tag [required] is the LLM Engine docker image tag +tag: 60ac144c55aad971cdd7f152f4f7816ce2fb7d2f +# context is a user-specified deployment tag. Can be used to +context: production +image: + # gatewayRepository [required] is the docker repository to pull the LLM Engine gateway image from + gatewayRepository: public.ecr.aws/b2z8n5q1/model-engine + # builderRepository [required] is the docker repository to pull the LLM Engine endpoint builder image from + builderRepository: public.ecr.aws/b2z8n5q1/model-engine + # cacherRepository [required] is the docker repository to pull the LLM Engine cacher image from + cacherRepository: public.ecr.aws/b2z8n5q1/model-engine + # forwarderRepository [required] is the docker repository to pull the LLM Engine forwarder image from + forwarderRepository: public.ecr.aws/b2z8n5q1/model-engine + # pullPolicy is the docker image pull policy + pullPolicy: Always + +secrets: + # kubernetesDatabaseSecretName or cloudDatabaseSecretName [required] + # is the name of the secret that contains the database credentials + kubernetesDatabaseSecretName: llm-engine-postgres-credentials + +# Azure Key Vault name to pull secrets from +keyvaultName: llm-engine-keyvault + +db: + runDbInitScript: false + runDbMigrationScript: false + +# serviceAccount [required] specifies the service account for LLM Engine server deployments (e.g gateway, cache, and builder deployments). +serviceAccount: + annotations: + # eks.amazonaws.com/role-arn [required] is the ARN of the IAM role that the service account will assume + eks.amazonaws.com/role-arn: arn:aws:iam::000000000000:role/k8s-main-llm-engine + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" + namespaces: [] + +imageBuilderServiceAccount: + create: true + annotations: + # eks.amazonaws.com/role-arn [required] is the ARN of the IAM role that the image builder service account will assume. Needs to have ecr permissions + eks.amazonaws.com/role-arn: arn:aws:iam::000000000000:role/k8s-main-llm-engine-image-builder + # Reads from serviceAccount.namespaces to determine which namespaces to create the image builder service account in + +# service specifies the service configuration for the main LLM Engine server. Users should setup their own ingress controller to expose the service. +service: + type: ClusterIP + port: 80 + +# virtualservice specifies the configuration of an Istio VirtualService +virtualservice: + enabled: true + annotations: { } + hostDomains: + - llm-engine.domain.com + gateways: + - default/internal-gateway + +hostDomain: + prefix: http:// + +# destinationrule specifies the configuration of an Istio DestinationRule +destinationrule: + enabled: true + annotations: { } + +# replicaCount specifies the amount of replica pods for each deployment +replicaCount: + # gateway is the main LLM Engine server deployment + gateway: 2 + # cacher is the kubernetes state caching deployment + cacher: 1 + # builder is the endpoint builder deployment + builder: 1 + +balloonConfig: + # If set to true, only high priority pods can preempt balloons. Otherwise, all pods can preempt balloons. + reserveHighPriority: true + +balloons: + # A low priority pod deployment for A10 GPU nodes + - acceleratorName: nvidia-ampere-a10 + replicaCount: 0 + # A low priority pod deployment for A100 GPU nodes + - acceleratorName: nvidia-ampere-a100 + replicaCount: 0 + # A low priority pod deployment for CPU nodes + - acceleratorName: cpu + replicaCount: 0 + # A low priority pod deployment for T4 GPU nodes + - acceleratorName: nvidia-tesla-t4 + replicaCount: 0 + # A low priority pod deployment for H100 GPU nodes + - acceleratorName: nvidia-hopper-h100 + replicaCount: 0 + gpuCount: 4 + +# autoscaling is the autoscaling configuration for LLM Engine server deployments (e.g gateway, cache, and builder deployments) +autoscaling: + horizontal: + enabled: true + minReplicas: 2 + maxReplicas: 10 + targetConcurrency: 50 + vertical: + enabled: false + prewarming: + enabled: false + +# for async endpoints, Celery autoscaler scales the number of pods based on number of requests +# num_shards is number of instances of the autoscaler +celery_autoscaler: + enabled: true + num_shards: 3 + +podDisruptionBudget: + enabled: true + minAvailable: 1 + +# resources specify the k8s resources for LLM Engine server deployments (e.g gateway, cache, and builder deployments) +resources: + requests: + cpu: 2 +# nodeSelector specifies the node selector for LLM Engine server deployments (e.g gateway, cache, and builder deployments) +nodeSelector: { } +# tolerations specifies the tolerations for LLM Engine server deployments (e.g gateway, cache, and builder deployments) +tolerations: [ ] +# affinity specifies the affinity for LLM Engine server deployments (e.g gateway, cache, and builder deployments) +affinity: { } + +# aws specifies the AWS configurations (by configMap) for LLM Engine server deployments +aws: + configMap: + name: default-config + create: true + profileName: default + +# serviceTemplate specifies additional flags for model endpoints +serviceTemplate: + securityContext: + capabilities: + drop: + - all + mountInfraConfig: true + # createServiceAccount/serviceAccountName/serviceAccountAnnotations specify whether to create a serviceAccount for + # inference pods. Assumes the inference pods run in a separate namespace to the LLM Engine control plane. + createServiceAccount: true + serviceAccountName: model-engine + serviceAccountAnnotations: + eks.amazonaws.com/role-arn: arn:aws:iam::000000000000:role/llm-engine + "helm.sh/hook": pre-install,pre-upgrade + "helm.sh/hook-weight": "-2" + +# config specifes the `data` field of the service config map +config: + values: + infra: + # cloud_provider [required]; either "aws" or "azure" + cloud_provider: aws + # k8s_cluster_name [required] is the name of the k8s cluster + k8s_cluster_name: main_cluster + # dns_host_domain [required] is the domain name of the k8s cluster + dns_host_domain: llm-engine.domain.com + # default_region [required] is the default AWS region for various resources (e.g ECR) + default_region: us-east-1 + # aws_account_id [required] is the AWS account ID for various resources (e.g ECR) + ml_account_id: "000000000000" + # docker_repo_prefix [required] is the prefix for AWS ECR repositories + docker_repo_prefix: "000000000000.dkr.ecr.us-east-1.amazonaws.com" + # redis_host [required if redis_aws_secret_name not present] is the hostname of the redis cluster you wish to connect + redis_host: llm-engine-prod-cache.use1.cache.amazonaws.com + # redis_aws_secret_name [optional] is the AWS secret that contains the connection info of the Redis cluster. + # The information provided should be as follows: + # scheme: either redis:// or rediss://, will default to redis:// + # auth_token (optional): an auth token for the Redis cluster + # host: the hostname of the Redis cluster + # port: the port of the Redis cluster + # query_params (optional): additional query parameters for the Redis cluster, will default to "" + # The url will be built as follows: + # {scheme}{host}:{port}/{db_index}{query_params} if auth_token is not provided, + # {scheme}:{auth_token}@{host}:{port}/{db_index}{query_params} if auth_token is provided + # db_index will be filled in by LLM Engine. + # This secret must be accessible by the default LLM Engine AWS role + # e.g. what is set by profile_ml_worker if provided + # redis_aws_secret_name: sample-prod/redis-credentials + # s3_bucket [required] is the S3 bucket you wish to connect + s3_bucket: "llm-engine" + # DB engine configs (This is SQLAlchemy heavy) + db_engine_pool_size: 10 + db_engine_max_overflow: 10 + db_engine_echo: false + db_engine_echo_pool: false + db_engine_disconnect_strategy: "pessimistic" + # prometheus_server_address [optional, required if you want scale from zero for sync/streaming endpoints] + # is the address of the Prometheus server to query for endpoint metrics + prometheus_server_address: "http://prometheus-server.istio-system.svc.cluster.local:80" + launch: + # endpoint_namespace [required] is K8s namespace the endpoints will be created in + endpoint_namespace: llm-engine + # cache_redis_aws_url is the full url for the redis cluster you wish to connect, + # cache_redis_azure_host is the redis cluster host when using cloud_provider azure + # cache_redis_aws_secret_name is an AWS secret that contains the Redis credentials. + # It has a field "cache-url" with the full URL of the Redis cluster (including db number). + # Other fields are ignored; e.g. you can use the secret for multiple purposes. + # This secret must be accessible by the default LLM Engine AWS role + # exactly one of cache_redis_aws_url, cache_redis_azure_host, or cache_redis_aws_secret_name must be provided + cache_redis_aws_url: redis://llm-engine-prod-cache.use1.cache.amazonaws.com:6379/15 + cache_redis_azure_host: llm-engine-cache.redis.cache.windows.net:6380 + cache_redis_aws_secret_name: sample-prod/redis-credentials + # s3_file_llm_fine_tuning_job_repository [required] is the S3 URI for the S3 bucket/key that you wish to save fine-tuned assests + s3_file_llm_fine_tuning_job_repository: "s3://llm-engine/llm-ft-job-repository" + # dd_trace_enabled specifies whether to enable datadog tracing, datadog must be installed in the cluster + dd_trace_enabled: false + istio_enabled: true + sensitive_log_mode: false + + # Asynchronous endpoints configs (coming soon) + sqs_profile: default + # sqs_queue_policy_template [required] is the IAM policy template for SQS queue for async endpoints. + sqs_queue_policy_template: > + { + "Version": "2012-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Sid": "__owner_statement", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::000000000000:root" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-east-1:000000000000:${queue_name}" + }, + { + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::000000000000:role/k8s-main-llm-engine" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-east-1:000000000000:${queue_name}" + } + ] + } + + sqs_queue_tag_template: > + { + "Spellbook-Serve-Endpoint-Id": "${endpoint_id}", + "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", + "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" + } + billing_queue_arn: "unused" + model_primitive_host: "unused" + hf_user_fine_tuned_weights_prefix: "s3://llm-engine/fine_tuned_weights" + sensitive_log_mode: false + tgi_repository: "text-generation-inference" + vllm_repository: "vllm" + lightllm_repository: "lightllm" + tensorrt_llm_repository: "tensorrt-llm" + batch_inference_vllm_repository: "llm-engine/batch-infer-vllm" + user_inference_base_repository: "launch/inference" + user_inference_pytorch_repository: "launch/inference/pytorch" + user_inference_tensorflow_repository: "launch/inference/tf" + docker_image_layer_cache_repository: "launch-docker-build-cache" + +# Triton enhanced endpoints (coming soon) +triton: + image: + repository: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv + tag: e83eccbc8959f90ebbe4bda618b61ec6ee2d8394-triton + +# imageCache specifies the image cache configuration for faster endpoint auto-scaling (coming soon) +imageCache: + devices: + - name: cpu + nodeSelector: + cpu-only: "true" + - name: a10 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-ampere-a10 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: a100 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-ampere-a100 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: t4 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-tesla-t4 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: h100 + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100 + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: h100-1g20gb + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100-1g20gb + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + - name: h100-3g40gb + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100-3g40gb + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + +# celeryBrokerType specifies the celery broker type for async endpoints, either "sqs" or "elasticache" +celeryBrokerType: sqs + +datadog: + enabled: false + +recommendedHardware: + byGpuMemoryGb: + - gpu_memory_le: 24 + cpus: 10 + gpus: 1 + memory: 24Gi + storage: 80Gi + gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 + - gpu_memory_le: 48 + cpus: 20 + gpus: 2 + memory: 48Gi + storage: 80Gi + gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 + - gpu_memory_le: 96 + cpus: 40 + gpus: 4 + memory: 96Gi + storage: 96Gi + gpu_type: nvidia-ampere-a10 + nodes_per_worker: 1 + - gpu_memory_le: 180 + cpus: 20 + gpus: 2 + memory: 160Gi + storage: 160Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 320 + cpus: 40 + gpus: 4 + memory: 320Gi + storage: 320Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 640 + cpus: 80 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 640 + cpus: 80 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 2 + byModelName: + - name: llama-3-8b-instruct-262k + cpus: 20 + gpus: 2 + memory: 40Gi + storage: 40Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - name: deepseek-coder-v2 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - name: deepseek-coder-v2-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 \ No newline at end of file diff --git a/clients/python/README.md b/clients/python/README.md index 9d4b7dbc..e9f6d289 100644 --- a/clients/python/README.md +++ b/clients/python/README.md @@ -15,18 +15,18 @@ pip install scale-llm-engine ### Usage If you are using LLM Engine, you can get your API key from -[https://spellbook.scale.com/settings](https://spellbook.scale.com/settings). +[https://spellbook.scale.com/settings](https://spellbook.scale.com/settings). Set the `SCALE_API_KEY` environment variable to your API key. If you are using your own infrastructure, you can set the -`LLM_ENGINE_SERVE_BASE_PATH` environment variable to the base URL of your +`LLM_ENGINE_BASE_PATH` environment variable to the base URL of your self-hosted `llmengine` endpoint. ```python from llmengine import Completion response = Completion.create( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, diff --git a/clients/python/llmengine/__init__.py b/clients/python/llmengine/__init__.py index 97324d59..206b405b 100644 --- a/clients/python/llmengine/__init__.py +++ b/clients/python/llmengine/__init__.py @@ -12,42 +12,133 @@ # See the License for the specific language governing permissions and # limitations under the License. -__version__ = "0.0.0.beta3" +__version__ = "0.0.0beta44" +import os from typing import Sequence +import requests from llmengine.completion import Completion from llmengine.data_types import ( + BatchCompletionsJob, + BatchCompletionsJobStatus, + BatchCompletionsModelConfig, CancelFineTuneResponse, + ChatCompletionV2Request, + ChatCompletionV2Response, CompletionOutput, CompletionStreamOutput, CompletionStreamResponse, + CompletionStreamV1Request, + CompletionStreamV1Response, CompletionSyncResponse, + CompletionSyncV1Request, + CompletionSyncV1Response, + CreateBatchCompletionsModelConfig, + CreateBatchCompletionsRequest, + CreateBatchCompletionsRequestContent, + CreateBatchCompletionsResponse, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1RequestContent, + CreateBatchCompletionsV1Response, + CreateBatchCompletionsV2ModelConfig, + CreateBatchCompletionsV2Request, + CreateBatchCompletionsV2RequestContent, + CreateBatchCompletionsV2Response, CreateFineTuneRequest, CreateFineTuneResponse, + DeleteFileResponse, DeleteLLMEndpointResponse, + FilteredChatCompletionV2Request, + FilteredCompletionV2Request, + GetFileContentResponse, + GetFileResponse, GetFineTuneResponse, GetLLMEndpointResponse, + ListFilesResponse, ListFineTunesResponse, ListLLMEndpointsResponse, + ModelDownloadRequest, + ModelDownloadResponse, + UploadFileResponse, + VLLMEndpointAdditionalArgs, ) +from llmengine.file import File from llmengine.fine_tuning import FineTune from llmengine.model import Model __all__: Sequence[str] = ( + "BatchCompletionsJob", + "CreateBatchCompletionsV2Response", + "FilteredCompletionV2Request", + "FilteredChatCompletionV2Request", + "BatchCompletionsJobStatus", + "CompletionSyncV1Request", + "CompletionSyncV1Response", + "CompletionStreamV1Request", + "CompletionStreamV1Response", "CancelFineTuneResponse", + "ChatCompletionV2Request", + "ChatCompletionV2Response", + "VLLMEndpointAdditionalArgs", "Completion", "CompletionOutput", "CompletionStreamOutput", "CompletionStreamResponse", "CompletionSyncResponse", + "CreateBatchCompletionsModelConfig", + "CreateBatchCompletionsRequest", + "CreateBatchCompletionsRequestContent", + "CreateBatchCompletionsResponse", + "CreateBatchCompletionsV1Request", + "CreateBatchCompletionsV1RequestContent", + "CreateBatchCompletionsV1Response", + "CreateBatchCompletionsV2Request", + "CreateBatchCompletionsV2RequestContent", + "CreateBatchCompletionsV2ModelConfig", + "BatchCompletionsModelConfig", "CreateFineTuneRequest", "CreateFineTuneResponse", + "DeleteFileResponse", "DeleteLLMEndpointResponse", + "ModelDownloadRequest", + "ModelDownloadResponse", + "GetFileContentResponse", + "File", "FineTune", + "GetFileResponse", "GetFineTuneResponse", "GetLLMEndpointResponse", + "ListFilesResponse", "ListFineTunesResponse", "ListLLMEndpointsResponse", "Model", + "UploadFileResponse", ) + + +def check_version(): + try: + current_version = __version__ + response = requests.get("https://pypi.org/pypi/scale-llm-engine/json") + latest_version = response.json()["info"]["version"] + + if current_version != latest_version: + print( + f"A newer version ({latest_version}) of 'scale-llm-engine' is available. Please upgrade!" + ) + print("To upgrade, run: pip install --upgrade scale-llm-engine") + print( + "Don't want to see this message? Set the environment variable 'LLM_ENGINE_DISABLE_VERSION_CHECK' to 'true'." + ) + except requests.RequestException: + # Handle exceptions related to the request (like timeouts, connection errors, etc.) + print( + "Failed to check for the most recent llm-engine package version. Please check your internet connection." + ) + except Exception: + print("Something went wrong with checking for the most recent llm-engine package version.") + + +if not os.environ.get("LLM_ENGINE_DISABLE_VERSION_CHECK"): + check_version() diff --git a/clients/python/llmengine/api_engine.py b/clients/python/llmengine/api_engine.py index fb0ec830..05d298cd 100644 --- a/clients/python/llmengine/api_engine.py +++ b/clients/python/llmengine/api_engine.py @@ -3,26 +3,48 @@ import json import os from functools import wraps +from io import BufferedReader from typing import Any, AsyncIterable, Dict, Iterator, Optional +from urllib.parse import urljoin import requests -from aiohttp import ClientSession, ClientTimeout +from aiohttp import BasicAuth, ClientSession, ClientTimeout from llmengine.errors import parse_error -SCALE_API_KEY = os.getenv("SCALE_API_KEY") -SPELLBOOK_API_URL = "https://api.spellbook.scale.com" -LLM_ENGINE_BASE_PATH = os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) +SPELLBOOK_API_URL = "https://api.spellbook.scale.com/llm-engine/" DEFAULT_TIMEOUT: int = 10 +base_path = None +api_key = None + + +def set_base_path(path): + global base_path + base_path = path + + +def get_base_path() -> str: + if base_path is not None: + return base_path + return os.getenv("LLM_ENGINE_BASE_PATH", SPELLBOOK_API_URL) + + +def set_api_key(key): + global api_key + api_key = key + def get_api_key() -> str: - return SCALE_API_KEY or "root" + if api_key is not None: + return api_key + env_api_key = os.getenv("SCALE_API_KEY") + return env_api_key or "root" def assert_self_hosted(func): @wraps(func) def inner(*args, **kwargs): - if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH: + if SPELLBOOK_API_URL == get_base_path(): raise ValueError("This feature is only available for self-hosted users.") return func(*args, **kwargs) @@ -32,18 +54,22 @@ def inner(*args, **kwargs): class APIEngine: @classmethod def validate_api_key(cls): - if SPELLBOOK_API_URL == LLM_ENGINE_BASE_PATH and not SCALE_API_KEY: + if SPELLBOOK_API_URL == get_base_path() and not get_api_key(): raise ValueError( "You must set SCALE_API_KEY in your environment to to use the LLM Engine API." ) @classmethod - def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: + def _get( + cls, resource_name: str, timeout: int, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.get( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, + auth=(api_key, ""), ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -52,14 +78,20 @@ def _get(cls, resource_name: str, timeout: int) -> Dict[str, Any]: @classmethod def put( - cls, resource_name: str, data: Optional[Dict[str, Any]], timeout: int + cls, + resource_name: str, + data: Optional[Dict[str, Any]], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.put( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, + auth=(api_key, ""), ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -67,12 +99,16 @@ def put( return payload @classmethod - def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: + def _delete( + cls, resource_name: str, timeout: int, headers: Optional[Dict[str, str]] = None + ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.delete( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, + auth=(api_key, ""), ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -80,13 +116,21 @@ def _delete(cls, resource_name: str, timeout: int) -> Dict[str, Any]: return payload @classmethod - def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Dict[str, Any]: + def post_sync( + cls, + resource_name: str, + data: Dict[str, Any], + timeout: int, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() response = requests.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, - headers={"x-api-key": api_key}, + auth=(api_key, ""), + headers={"x-api-key": api_key, **(headers or {})}, ) if response.status_code != 200: raise parse_error(response.status_code, response.content) @@ -95,14 +139,20 @@ def post_sync(cls, resource_name: str, data: Dict[str, Any], timeout: int) -> Di @classmethod def post_stream( - cls, resource_name: str, data: Dict[str, Any], timeout: int + cls, + resource_name: str, + data: Dict[str, Any], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> Iterator[Dict[str, Any]]: + base_path = get_base_path() api_key = get_api_key() response = requests.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), + urljoin(base_path, resource_name), json=data, timeout=timeout, - headers={"x-api-key": api_key}, + headers={"x-api-key": api_key, **(headers or {})}, + auth=(api_key, ""), stream=True, ) if response.status_code != 200: @@ -124,17 +174,44 @@ def post_stream( except json.JSONDecodeError: raise ValueError(f"Invalid JSON payload: {payload_data}") + @classmethod + def post_file( + cls, + resource_name: str, + files: Dict[str, BufferedReader], + timeout: int, + headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + base_path = get_base_path() + api_key = get_api_key() + response = requests.post( + urljoin(base_path, resource_name), + files=files, + timeout=timeout, + headers={"x-api-key": api_key, **(headers or {})}, + auth=(api_key, ""), + ) + if response.status_code != 200: + raise parse_error(response.status_code, response.content) + payload = response.json() + return payload + @classmethod async def apost_sync( - cls, resource_name: str, data: Dict[str, Any], timeout: int + cls, + resource_name: str, + data: Dict[str, Any], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> Dict[str, Any]: + base_path = get_base_path() api_key = get_api_key() async with ClientSession( - timeout=ClientTimeout(timeout), headers={"x-api-key": api_key} + timeout=ClientTimeout(timeout), + headers={"x-api-key": api_key, **(headers or {})}, + auth=BasicAuth(api_key, ""), ) as session: - async with session.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), json=data - ) as resp: + async with session.post(urljoin(base_path, resource_name), json=data) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.read()) payload = await resp.json() @@ -142,15 +219,20 @@ async def apost_sync( @classmethod async def apost_stream( - cls, resource_name: str, data: Dict[str, Any], timeout: int + cls, + resource_name: str, + data: Dict[str, Any], + timeout: int, + headers: Optional[Dict[str, str]] = None, ) -> AsyncIterable[Dict[str, Any]]: + base_path = get_base_path() api_key = get_api_key() async with ClientSession( - timeout=ClientTimeout(timeout), headers={"x-api-key": api_key} + timeout=ClientTimeout(timeout), + headers={"x-api-key": api_key, **(headers or {})}, + auth=BasicAuth(api_key, ""), ) as session: - async with session.post( - os.path.join(LLM_ENGINE_BASE_PATH, resource_name), json=data - ) as resp: + async with session.post(urljoin(base_path, resource_name), json=data) as resp: if resp.status != 200: raise parse_error(resp.status, await resp.read()) async for byte_payload in resp.content: diff --git a/clients/python/llmengine/completion.py b/clients/python/llmengine/completion.py index 6c6f2039..8f972e91 100644 --- a/clients/python/llmengine/completion.py +++ b/clients/python/llmengine/completion.py @@ -1,13 +1,27 @@ -from typing import AsyncIterable, Iterator, Union +from typing import Any, AsyncIterable, Dict, Iterator, List, Optional, Union, cast from llmengine.api_engine import APIEngine from llmengine.data_types import ( + BatchCompletionContent, CompletionStreamResponse, CompletionStreamV1Request, CompletionSyncResponse, CompletionSyncV1Request, + CpuSpecificationType, + CreateBatchCompletionsModelConfig, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1RequestContent, + CreateBatchCompletionsV1Response, + CreateBatchCompletionsV2Request, + CreateBatchCompletionsV2Response, + GpuType, + StorageSpecificationType, + ToolConfig, ) +COMPLETION_TIMEOUT = 300 +HTTP_TIMEOUT = 60 + class Completion(APIEngine): """ @@ -29,8 +43,20 @@ async def acreate( prompt: str, max_new_tokens: int = 20, temperature: float = 0.2, - timeout: int = 10, + stop_sequences: Optional[List[str]] = None, + return_token_log_probs: Optional[bool] = False, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + include_stop_str_in_output: Optional[bool] = None, + guided_json: Optional[Dict[str, Any]] = None, + guided_regex: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + guided_grammar: Optional[str] = None, + timeout: int = COMPLETION_TIMEOUT, stream: bool = False, + request_headers: Optional[Dict[str, str]] = None, ) -> Union[CompletionSyncResponse, AsyncIterable[CompletionStreamResponse]]: """ Creates a completion for the provided prompt and parameters asynchronously (with `asyncio`). @@ -57,8 +83,51 @@ async def acreate( [Model Zoo](../../model_zoo) for information on each supported model's context length. temperature (float): - What sampling temperature to use, in the range `(0, 1]`. Higher values like 0.8 will make the output + What sampling temperature to use, in the range `[0, 1]`. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + When temperature is 0 [greedy search](https://huggingface.co/docs/transformers/generation_strategies#greedy-search) is used. + + stop_sequences (Optional[List[str]]): + One or more sequences where the API will stop generating tokens for the current completion. + + return_token_log_probs (Optional[bool]): + Whether to return the log probabilities of generated tokens. + When True, the response will include a list of tokens and their log probabilities. + + presence_penalty (Optional[float]): + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + https://platform.openai.com/docs/guides/gpt/parameter-details + Range: [0.0, 2.0]. Higher values encourage the model to use new tokens. + + frequency_penalty (Optional[float]): + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + https://platform.openai.com/docs/guides/gpt/parameter-details + Range: [0.0, 2.0]. Higher values encourage the model to use new tokens. + + top_k (Optional[int]): + Integer that controls the number of top tokens to consider. + Range: [1, infinity). -1 means consider all tokens. + + top_p (Optional[float]): + Float that controls the cumulative probability of the top tokens to consider. + Range: (0.0, 1.0]. 1.0 means consider all tokens. + + include_stop_str_in_output (Optional[bool]): + Whether to include the stop sequence in the output. Default to False. + + guided_json (Optional[Dict[str, Any]]): + If specified, the output will follow the JSON schema. For examples see https://json-schema.org/learn/miscellaneous-examples. + + guided_regex (Optional[str]): + If specified, the output will follow the regex pattern. + + guided_choice (Optional[List[str]]): + If specified, the output will be exactly one of the choices. + + guided_grammar (Optional[str]): + If specified, the output will follow the context-free grammar provided. timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -78,7 +147,7 @@ async def acreate( async def main(): response = await Completion.acreate( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, @@ -108,7 +177,7 @@ async def main(): async def main(): stream = await Completion.acreate( - model="llama-7b", + model="llama-2-7b", prompt="why is the sky blue?", max_new_tokens=5, temperature=0.2, @@ -141,6 +210,7 @@ async def _acreate_stream( resource_name=f"v1/llm/completions-stream?model_endpoint_name={model}", data=data, timeout=timeout, + headers=request_headers, ) async for chunk in response: yield CompletionStreamResponse.parse_obj(chunk) @@ -150,6 +220,17 @@ async def _acreate_stream( prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature, + stop_sequences=stop_sequences, + return_token_log_probs=return_token_log_probs, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + top_k=top_k, + top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, + guided_grammar=guided_grammar, timeout=timeout, ) @@ -161,11 +242,25 @@ async def _acreate_sync(**kwargs) -> CompletionSyncResponse: resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", data=data, timeout=timeout, + headers=request_headers, ) return CompletionSyncResponse.parse_obj(response) return await _acreate_sync( - prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature + prompt=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + stop_sequences=stop_sequences, + return_token_log_probs=return_token_log_probs, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + top_k=top_k, + top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, + guided_grammar=guided_grammar, ) @classmethod @@ -175,8 +270,20 @@ def create( prompt: str, max_new_tokens: int = 20, temperature: float = 0.2, - timeout: int = 10, + stop_sequences: Optional[List[str]] = None, + return_token_log_probs: Optional[bool] = False, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + top_k: Optional[int] = None, + top_p: Optional[float] = None, + include_stop_str_in_output: Optional[bool] = None, + guided_json: Optional[Dict[str, Any]] = None, + guided_regex: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + guided_grammar: Optional[str] = None, + timeout: int = COMPLETION_TIMEOUT, stream: bool = False, + request_headers: Optional[Dict[str, str]] = None, ) -> Union[CompletionSyncResponse, Iterator[CompletionStreamResponse]]: """ Creates a completion for the provided prompt and parameters synchronously. @@ -204,8 +311,51 @@ def create( [Model Zoo](../../model_zoo) for information on each supported model's context length. temperature (float): - What sampling temperature to use, in the range `(0, 1]`. Higher values like 0.8 will make the output + What sampling temperature to use, in the range `[0, 1]`. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + When temperature is 0 [greedy search](https://huggingface.co/docs/transformers/generation_strategies#greedy-search) is used. + + stop_sequences (Optional[List[str]]): + One or more sequences where the API will stop generating tokens for the current completion. + + return_token_log_probs (Optional[bool]): + Whether to return the log probabilities of generated tokens. + When True, the response will include a list of tokens and their log probabilities. + + presence_penalty (Optional[float]): + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + https://platform.openai.com/docs/guides/gpt/parameter-details + Range: [0.0, 2.0]. Higher values encourage the model to use new tokens. + + frequency_penalty (Optional[float]): + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + https://platform.openai.com/docs/guides/gpt/parameter-details + Range: [0.0, 2.0]. Higher values encourage the model to use new tokens. + + top_k (Optional[int]): + Integer that controls the number of top tokens to consider. + Range: [1, infinity). -1 means consider all tokens. + + top_p (Optional[float]): + Float that controls the cumulative probability of the top tokens to consider. + Range: (0.0, 1.0]. 1.0 means consider all tokens. + + include_stop_str_in_output (Optional[bool]): + Whether to include the stop sequence in the output. Default to False. + + guided_json (Optional[Dict[str, Any]]): + If specified, the output will follow the JSON schema. + + guided_regex (Optional[str]): + If specified, the output will follow the regex pattern. + + guided_choice (Optional[List[str]]): + If specified, the output will be exactly one of the choices. + + guided_grammar (Optional[str]): + If specified, the output will follow the context-free grammar provided. timeout (int): Timeout in seconds. This is the maximum amount of time you are willing to wait for a response. @@ -224,7 +374,7 @@ def create( from llmengine import Completion response = Completion.create( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, @@ -250,7 +400,7 @@ def create( from llmengine import Completion stream = Completion.create( - model="llama-7b", + model="llama-2-7b", prompt="why is the sky blue?", max_new_tokens=5, temperature=0.2, @@ -279,21 +429,321 @@ def _create_stream(**kwargs): resource_name=f"v1/llm/completions-stream?model_endpoint_name={model}", data=data_stream, timeout=timeout, + headers=request_headers, ) for chunk in response_stream: yield CompletionStreamResponse.parse_obj(chunk) return _create_stream( - prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature + prompt=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + stop_sequences=stop_sequences, + return_token_log_probs=return_token_log_probs, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + top_k=top_k, + top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, + guided_grammar=guided_grammar, ) else: data = CompletionSyncV1Request( - prompt=prompt, max_new_tokens=max_new_tokens, temperature=temperature + prompt=prompt, + max_new_tokens=max_new_tokens, + temperature=temperature, + stop_sequences=stop_sequences, + return_token_log_probs=return_token_log_probs, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + top_k=top_k, + top_p=top_p, + include_stop_str_in_output=include_stop_str_in_output, + guided_json=guided_json, + guided_regex=guided_regex, + guided_choice=guided_choice, + guided_grammar=guided_grammar, ).dict() response = cls.post_sync( resource_name=f"v1/llm/completions-sync?model_endpoint_name={model}", data=data, timeout=timeout, + headers=request_headers, ) return CompletionSyncResponse.parse_obj(response) + + @classmethod + def batch_create( + cls, + output_data_path: str, + model_config: CreateBatchCompletionsModelConfig, + content: Optional[BatchCompletionContent] = None, + input_data_path: Optional[str] = None, + data_parallelism: int = 1, + max_runtime_sec: int = 24 * 3600, + labels: Optional[Dict[str, str]] = None, + priority: Optional[str] = None, + use_v2: bool = False, + tool_config: Optional[ToolConfig] = None, + cpus: Optional[CpuSpecificationType] = None, + gpus: Optional[int] = None, + memory: Optional[StorageSpecificationType] = None, + gpu_type: Optional[GpuType] = None, + storage: Optional[StorageSpecificationType] = None, + request_headers: Optional[Dict[str, str]] = None, + ) -> Union[CreateBatchCompletionsV1Response, CreateBatchCompletionsV2Response]: + """ + Creates a batch completion for the provided input data. The job runs offline and does not depend on an existing model endpoint. + + Prompts can be passed in from an input file, or as a part of the request. + + Args: + output_data_path (str): + The path to the output file. The output file will be a JSON file containing the completions. + + model_config (CreateBatchCompletionsModelConfig): + The model configuration to use for the batch completion. + + content (Optional[CreateBatchCompletionsRequestContent]): + The content to use for the batch completion. Either one of `content` or `input_data_path` must be provided. + + input_data_path (Optional[str]): + The path to the input file. The input file should be a JSON file with data of type `BatchCompletionsRequestContent`. Either one of `content` or `input_data_path` must be provided. + + data_parallelism (int): + The number of parallel jobs to run. Data will be evenly distributed to the jobs. Defaults to 1. + + priority (str): + Priority of the batch inference job. Default to None. + + max_runtime_sec (int): + The maximum runtime of the batch completion in seconds. Defaults to 24 hours. + + use_v2 (bool): + Whether to use the v2 batch completion API. Defaults to False. + + tool_config (Optional[ToolConfig]): + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + Currently only Python code evaluator is supported. + Python code context starts with "\`\`\`python\\n" and ends with "\\n>>>\\n", data before "\\n\`\`\`\\n" and content end will be replaced by the Python execution results. + Please format prompts accordingly and provide examples so LLMs could properly generate Python code. + + Returns: + response (CreateBatchCompletionsResponse): The response containing the job id. + + === "Batch completions with prompts in the request" + ```python + from llmengine import Completion + from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent + + response = Completion.batch_create( + output_data_path="s3://my-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + content=CreateBatchCompletionsRequestContent( + prompts=["What is deep learning", "What is a neural network"], + max_new_tokens=10, + temperature=0.0 + ) + ) + print(response.json()) + ``` + + === "Batch completions with prompts in a file and with 2 parallel jobs" + ```python + from llmengine import Completion + from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent + + # Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" + + response = Completion.batch_create( + input_data_path="s3://my-input-path", + output_data_path="s3://my-output-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + data_parallelism=2 + ) + print(response.json()) + ``` + + === "Batch completions with prompts and use tool" + ```python + from llmengine import Completion + from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent, ToolConfig + + # Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" + + response = Completion.batch_create( + input_data_path="s3://my-input-path", + output_data_path="s3://my-output-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + data_parallelism=2, + tool_config=ToolConfig( + name="code_evaluator", + ) + ) + print(response.json()) + ``` + + === "V2 Batch completions with prompts in the request" + ```python + from llmengine import Completion + from llmengine.data_types import CreateBatchCompletionsModelConfig, FilteredChatCompletionV2Request + + model_config = CreateBatchCompletionsModelConfig( + model="gemma-2-2b-it", + checkpoint_path="s3://path-to-checkpoint", + ) + + content = { + "messages": [ + { + "role": "user", + "content": "What is a good place for travel in the US?", + }, + {"role": "assistant", "content": "California."}, + {"role": "user", "content": "What can I do in California?"}, + ], + "logprobs": True, + } + + response = Completion.batch_create( + output_data_path="testoutput", + model_config=model_config, + content=[FilteredChatCompletionV2Request(**content)], + use_v2=True, + labels={"team": "my-team", "product": "my-product"}, + ) + + print(response.json()) + """ + labels = labels if labels else model_config.labels + if use_v2: + data = CreateBatchCompletionsV2Request( + model_config=model_config, + content=content, + input_data_path=input_data_path, + output_data_path=output_data_path, + data_parallelism=data_parallelism, + labels=labels, + max_runtime_sec=max_runtime_sec, + tool_config=tool_config, + priority=priority, + cpus=cpus, + gpus=gpus, + memory=memory, + gpu_type=gpu_type, + storage=storage, + ).dict() + response = cls.post_sync( + resource_name="v2/batch-completions", + data=data, + timeout=HTTP_TIMEOUT, + headers=request_headers, + ) + return CreateBatchCompletionsV2Response.parse_obj(response) + else: + if input_data_path is None and not isinstance( + content, CreateBatchCompletionsV1RequestContent + ): + raise ValueError( + "Either input_data_path or content must be provided. If content is provided, it must be of type CreateBatchCompletionsV1RequestContent." + ) + + content = cast(Optional[CreateBatchCompletionsV1RequestContent], content) + data = CreateBatchCompletionsV1Request( + model_config=model_config, + content=content, + input_data_path=input_data_path, + output_data_path=output_data_path, + data_parallelism=data_parallelism, + max_runtime_sec=max_runtime_sec, + tool_config=tool_config, + ).dict() + response = cls.post_sync( + resource_name="v1/llm/batch-completions", + data=data, + timeout=HTTP_TIMEOUT, + headers=request_headers, + ) + return CreateBatchCompletionsV1Response.parse_obj(response) + + @classmethod + def get_batch_completion( + cls, + job_id: str, + request_headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Get the status of a batch completion job. + + Args: + job_id (str): + The job id of the batch completion job. + + Returns: + response (Dict[str, Any]): The response containing the job status. + + === "Get batch completion status" + ```python + from llmengine import Completion + + response = Completion.get_batch_completion(job_id="job_id") + print( + f"Current job status for {job_id} is {job.status}" + ) + ``` + """ + response = cls._get( + resource_name=f"v2/batch-completions/{job_id}", + timeout=HTTP_TIMEOUT, + headers=request_headers, + ) + return response + + @classmethod + def cancel_batch_completion( + cls, + job_id: str, + request_headers: Optional[Dict[str, str]] = None, + ) -> Dict[str, Any]: + """ + Cancel a batch completion job. + + Args: + job_id (str): + The job id of the batch completion job. + + Returns: + response (Dict[str, Any]): The response containing the job status. + + === "Cancel batch completion job" + ```python + from llmengine import Completion + + response = Completion.cancel_batch_completion(job_id="job-id") + print(response) + ``` + """ + response = cls.post_sync( + resource_name=f"v2/batch-completions/{job_id}/actions/cancel", + data={}, + timeout=HTTP_TIMEOUT, + headers=request_headers, + ) + return response diff --git a/clients/python/llmengine/data_types.py b/clients/python/llmengine/data_types.py deleted file mode 100644 index 90cd201c..00000000 --- a/clients/python/llmengine/data_types.py +++ /dev/null @@ -1,420 +0,0 @@ -""" -DTOs for LLM APIs. -""" -import datetime -from enum import Enum -from typing import Any, Dict, List, Literal, Optional, Union - -from pydantic import BaseModel, Field, HttpUrl - -CpuSpecificationType = Union[str, int, float] -StorageSpecificationType = Union[str, int, float] # TODO(phil): we can make this more specific. - - -class LLMInferenceFramework(str, Enum): - DEEPSPEED = "deepspeed" - TEXT_GENERATION_INFERENCE = "text_generation_inference" - - -class LLMSource(str, Enum): - HUGGING_FACE = "hugging_face" - - -class Quantization(str, Enum): - BITSANDBYTES = "bitsandbytes" - - -class GpuType(str, Enum): - """Lists allowed GPU types for LLMEngine.""" - - NVIDIA_TESLA_T4 = "nvidia-tesla-t4" - NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" - NVIDIA_AMPERE_A100 = "nvidia-a100" - - -class ModelEndpointType(str, Enum): - ASYNC = "async" - SYNC = "sync" - STREAMING = "streaming" - - -class ModelEndpointStatus(str, Enum): - # Duplicates common/types::EndpointStatus, when refactor is done, delete the old type - # See EndpointStatus for status explanations - READY = "READY" - UPDATE_PENDING = "UPDATE_PENDING" - UPDATE_IN_PROGRESS = "UPDATE_IN_PROGRESS" - UPDATE_FAILED = "UPDATE_FAILED" - DELETE_IN_PROGRESS = "DELETE_IN_PROGRESS" - - -class CallbackBasicAuth(BaseModel): - kind: Literal["basic"] - username: str - password: str - - -class CallbackmTLSAuth(BaseModel): - kind: Literal["mtls"] - cert: str - key: str - - -class CallbackAuth(BaseModel): - __root__: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") - - -class ModelEndpointDeploymentState(BaseModel): - """ - This is the entity-layer class for the deployment settings related to a Model Endpoint. - """ - - min_workers: int = Field(..., ge=0) - max_workers: int = Field(..., ge=0) - per_worker: int = Field(..., gt=0) - available_workers: Optional[int] = Field(default=None, ge=0) - unavailable_workers: Optional[int] = Field(default=None, ge=0) - - -class ModelEndpointResourceState(BaseModel): - """ - This is the entity-layer class for the resource settings per worker of a Model Endpoint. - """ - - cpus: CpuSpecificationType # TODO(phil): try to use decimal.Decimal - gpus: int = Field(..., ge=0) - memory: StorageSpecificationType - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] - - -class GetModelEndpointResponse(BaseModel): - id: str - name: str - endpoint_type: ModelEndpointType - destination: str - deployment_name: Optional[str] = Field(default=None) - metadata: Optional[Dict[str, Any]] = Field(default=None) # TODO: JSON type - bundle_name: str - status: ModelEndpointStatus - post_inference_hooks: Optional[List[str]] = Field(default=None) - default_callback_url: Optional[HttpUrl] = Field(default=None) - default_callback_auth: Optional[CallbackAuth] = Field(default=None) - labels: Optional[Dict[str, str]] = Field(default=None) - aws_role: Optional[str] = Field(default=None) - results_s3_bucket: Optional[str] = Field(default=None) - created_by: str - created_at: datetime.datetime - last_updated_at: datetime.datetime - deployment_state: Optional[ModelEndpointDeploymentState] = Field(default=None) - resource_state: Optional[ModelEndpointResourceState] = Field(default=None) - num_queued_items: Optional[int] = Field(default=None) - public_inference: Optional[bool] = Field(default=None) - - -class CreateLLMEndpointRequest(BaseModel): - name: str - - # LLM specific fields - model_name: str - source: LLMSource = LLMSource.HUGGING_FACE - inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED - inference_framework_image_tag: str - num_shards: int - """ - Number of shards to distribute the model onto GPUs. - """ - - # General endpoint fields - metadata: Dict[str, Any] # TODO: JSON type - post_inference_hooks: Optional[List[str]] - endpoint_type: ModelEndpointType = ModelEndpointType.SYNC - cpus: CpuSpecificationType - gpus: int - memory: StorageSpecificationType - gpu_type: GpuType - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] - min_workers: int - max_workers: int - per_worker: int - labels: Dict[str, str] - prewarm: Optional[bool] - high_priority: Optional[bool] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] - public_inference: Optional[bool] = True # LLM endpoints are public by default. - - -class CreateLLMEndpointResponse(BaseModel): - endpoint_creation_task_id: str - - -class GetLLMEndpointResponse(BaseModel): - """ - Response object for retrieving a Model. - """ - - id: Optional[str] = Field( - default=None, description="(For self-hosted users) The autogenerated ID of the model." - ) - """(For self-hosted users) The autogenerated ID of the model.""" - - name: str = Field( - description="The name of the model. Use this for making inference requests to the model." - ) - """The name of the model. Use this for making inference requests to the model.""" - - model_name: Optional[str] = Field( - default=None, - description="(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.", - ) - """(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.""" - - source: LLMSource = Field(description="The source of the model, e.g. Hugging Face.") - """The source of the model, e.g. Hugging Face.""" - - inference_framework: LLMInferenceFramework = Field( - description="The inference framework used by the model." - ) - """(For self-hosted users) The inference framework used by the model.""" - - inference_framework_tag: Optional[str] = Field( - default=None, - description="(For self-hosted users) The Docker image tag used to run the model.", - ) - """(For self-hosted users) The Docker image tag used to run the model.""" - - num_shards: Optional[int] = Field( - default=None, description="(For self-hosted users) The number of shards." - ) - """(For self-hosted users) The number of shards.""" - - quantize: Optional[Quantization] = Field( - default=None, description="(For self-hosted users) The quantization method." - ) - """(For self-hosted users) The quantization method.""" - - spec: Optional[GetModelEndpointResponse] = Field( - default=None, description="(For self-hosted users) Model endpoint details." - ) - """(For self-hosted users) Model endpoint details.""" - - -class ListLLMEndpointsResponse(BaseModel): - """ - Response object for listing Models. - """ - - model_endpoints: List[GetLLMEndpointResponse] = Field( - ..., - description="The list of models.", - ) - """ - A list of Models, represented as `GetLLMEndpointResponse`s. - """ - - -class DeleteLLMEndpointResponse(BaseModel): - """ - Response object for deleting a Model. - """ - - deleted: bool = Field(..., description="Whether deletion was successful.") - """ - Whether the deletion succeeded. - """ - - -class CompletionSyncV1Request(BaseModel): - """ - Request object for a synchronous prompt completion task. - """ - - prompt: str = Field(..., min_length=1) - max_new_tokens: int = Field(..., gt=0) - temperature: float = Field(..., gt=0.0) - - -class CompletionOutput(BaseModel): - """ - Represents the output of a completion request to a model. - """ - - text: str - """The text of the completion.""" - - num_completion_tokens: int - """Number of tokens in the completion.""" - - -class CompletionSyncResponse(BaseModel): - """ - Response object for a synchronous prompt completion. - """ - - request_id: str - """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs - associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as - follows: - - * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. - * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability - provider.""" - - output: CompletionOutput - """Completion output.""" - - -class CompletionStreamV1Request(BaseModel): - """ - Request object for a streaming prompt completion. - """ - - prompt: str = Field(..., min_length=1) - max_new_tokens: int = Field(..., gt=0) - temperature: float = Field(..., gt=0.0) - - -class CompletionStreamOutput(BaseModel): - text: str - """The text of the completion.""" - - finished: bool - """Whether the completion is finished.""" - - num_completion_tokens: Optional[int] = None - """Number of tokens in the completion.""" - - -class CompletionStreamResponse(BaseModel): - """ - Response object for a stream prompt completion task. - """ - - request_id: str - """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs - associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as - follows: - - * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. - * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability - provider.""" - - output: Optional[CompletionStreamOutput] = None - """Completion output.""" - - -class CreateFineTuneRequest(BaseModel): - """ - Request object for creating a FineTune. - """ - - model: str = Field(..., description="Identifier of base model to train from.") - """Identifier of base model to train from.""" - - training_file: str = Field( - ..., - description="Path to file of training dataset. Dataset must be a csv with columns 'prompt' and 'response'.", - ) - """Path to file of training dataset. Dataset must be a csv with columns 'prompt' and 'response'.""" - - validation_file: Optional[str] = Field( - default=None, - description="Path to file of validation dataset. Has the same format as training_file. If not provided, we will generate a split from the training dataset.", - ) - """Path to file of validation dataset. Has the same format as training_file. If not provided, we will generate a split from the training dataset.""" - - hyperparameters: Optional[Dict[str, Any]] = Field( - default=None, description="Hyperparameters to pass in to training job." - ) - """Hyperparameters to pass in to training job.""" - - suffix: Optional[str] = Field( - default=None, - description="Optional user-provided identifier suffix for the fine-tuned model.", - ) - """Optional user-provided identifier suffix for the fine-tuned model.""" - - -class CreateFineTuneResponse(BaseModel): - """ - Response object for creating a FineTune. - """ - - fine_tune_id: str = Field(..., description="ID of the created fine-tuning job.") - """ - The ID of the FineTune. - """ - - -class BatchJobStatus(str, Enum): - PENDING = "PENDING" - RUNNING = "RUNNING" - SUCCESS = "SUCCESS" - FAILURE = "FAILURE" - CANCELLED = "CANCELLED" - UNDEFINED = "UNDEFINED" - TIMEOUT = "TIMEOUT" - - -class GetFineTuneResponse(BaseModel): - """ - Response object for retrieving a FineTune. - """ - - fine_tune_id: str = Field(..., description="ID of the requested job.") - """ - The ID of the FineTune. - """ - - status: BatchJobStatus = Field(..., description="Status of the requested job.") - """ - The status of the FineTune job. - """ - - -class ListFineTunesResponse(BaseModel): - """ - Response object for listing FineTunes. - """ - - jobs: List[GetFineTuneResponse] = Field( - ..., description="List of fine-tuning jobs and their statuses." - ) - """ - A list of FineTunes, represented as `GetFineTuneResponse`s. - """ - - -class CancelFineTuneResponse(BaseModel): - """ - Response object for cancelling a FineTune. - """ - - success: bool = Field(..., description="Whether cancellation was successful.") - """ - Whether the cancellation succeeded. - """ - - -class LLMFineTuneEvent(BaseModel): - """ - Response object one FineTune event. - """ - - timestamp: Optional[float] = Field( - description="Timestamp of the event.", - default=None, - ) - message: str = Field(description="Message of the event.") - level: str = Field(description="Logging level of the event.") - - -class GetFineTuneEventsResponse(BaseModel): - """ - Response object for getting events for a FineTune. - """ - - events: List[LLMFineTuneEvent] = Field(..., description="List of fine-tuning events.") diff --git a/clients/python/llmengine/data_types/__init__.py b/clients/python/llmengine/data_types/__init__.py new file mode 100644 index 00000000..a666add4 --- /dev/null +++ b/clients/python/llmengine/data_types/__init__.py @@ -0,0 +1,26 @@ +""" +DTOs for LLM APIs. +""" + +from typing_extensions import TypeAlias + +from .batch_completion import * # noqa: F403 +from .chat_completion import * # noqa: F403 +from .completion import * # noqa: F403 +from .core import * # noqa: F403 +from .model_endpoints import * # noqa: F403 +from .rest import * # noqa: F403 +from .vllm import * # noqa: F403 + +# Alias for backwards compatibility +CreateBatchCompletionsRequestContent: TypeAlias = ( + CreateBatchCompletionsV1RequestContent # noqa: F405 +) +CreateBatchCompletionsRequest: TypeAlias = CreateBatchCompletionsV1Request # noqa: F405 +CreateBatchCompletionsResponse: TypeAlias = CreateBatchCompletionsV1Response # noqa: F405 +CreateBatchCompletionsModelConfig: TypeAlias = CreateBatchCompletionsV1ModelConfig # noqa: F405 + +CompletionSyncRequest: TypeAlias = CompletionSyncV1Request # noqa: F405 +CompletionSyncResponse: TypeAlias = CompletionSyncV1Response # noqa: F405 +CompletionStreamRequest: TypeAlias = CompletionStreamV1Request # noqa: F405 +CompletionStreamResponse: TypeAlias = CompletionStreamV1Response # noqa: F405 diff --git a/clients/python/llmengine/data_types/batch_completion.py b/clients/python/llmengine/data_types/batch_completion.py new file mode 100644 index 00000000..6935351f --- /dev/null +++ b/clients/python/llmengine/data_types/batch_completion.py @@ -0,0 +1,302 @@ +from enum import Enum +from typing import Dict, List, Optional, Union + +from typing_extensions import TypeAlias + +from .chat_completion import ChatCompletionV2Request, ChatCompletionV2Response +from .completion import CompletionOutput, CompletionV2Request, CompletionV2Response +from .pydantic_types import BaseModel, Field +from .rest import CpuSpecificationType, GpuType, StorageSpecificationType +from .vllm import VLLMModelConfig + + +# Common DTOs for batch completions +class ToolConfig(BaseModel): + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + name: str + """ + Name of the tool to use for the batch inference. + """ + max_iterations: Optional[int] = 10 + """ + Maximum number of iterations to run the tool. + """ + execution_timeout_seconds: Optional[int] = 60 + """ + Maximum runtime of the tool in seconds. + """ + should_retry_on_error: Optional[bool] = True + """ + Whether to retry the tool on error. + """ + + +class BatchCompletionsModelConfig(VLLMModelConfig): + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + checkpoint_path: Optional[str] = Field( + default=None, description="Path to the checkpoint to load the model from." + ) + + num_shards: Optional[int] = Field( + default=1, + ge=1, + description=""" +Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. +System may decide to use a different number than the given value. +""", + ) + + max_context_length: Optional[int] = Field( + default=None, + ge=1, + description="Maximum context length to use for the model. Defaults to the max allowed by the model", + ) + + seed: Optional[int] = Field(default=None, description="Random seed for the model.") + + response_role: Optional[str] = Field( + default=None, + description="Role of the response in the conversation. Only supported in chat completions.", + ) + + +class BatchCompletionsRequestBase(BaseModel): + input_data_path: Optional[str] = Field( + default=None, + description="Path to the input file. The input file should be a JSON file of type List[CreateBatchCompletionsRequestContent].", + ) + output_data_path: str = Field( + description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." + ) + + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + + data_parallelism: Optional[int] = Field( + default=1, + ge=1, + le=64, + description="Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference.", + ) + + max_runtime_sec: Optional[int] = Field( + default=24 * 3600, + ge=1, + le=2 * 24 * 3600, + description="Maximum runtime of the batch inference in seconds. Default to one day.", + ) + + priority: Optional[str] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + + tool_config: Optional[ToolConfig] = Field( + default=None, + description=""" +Configuration for tool use. +NOTE: this config is highly experimental and signature will change significantly in future iterations.""", + ) + + cpus: Optional[CpuSpecificationType] = Field( + default=None, description="CPUs to use for the batch inference." + ) + gpus: Optional[int] = Field( + default=None, description="Number of GPUs to use for the batch inference." + ) + memory: Optional[StorageSpecificationType] = Field( + default=None, description="Amount of memory to use for the batch inference." + ) + gpu_type: Optional[GpuType] = Field( + default=None, description="GPU type to use for the batch inference." + ) + storage: Optional[StorageSpecificationType] = Field( + default=None, description="Storage to use for the batch inference." + ) + nodes_per_worker: Optional[int] = Field( + default=None, description="Number of nodes per worker for the batch inference." + ) + + +# V1 DTOs for batch completions +CompletionV1Output = CompletionOutput + + +class CreateBatchCompletionsV1ModelConfig(BatchCompletionsModelConfig): + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + + +class CreateBatchCompletionsV1RequestContent(BaseModel): + prompts: List[str] + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. + """ + + +class CreateBatchCompletionsV1Request(BatchCompletionsRequestBase): + """ + Request object for batch completions. + """ + + content: Optional[CreateBatchCompletionsV1RequestContent] = None + """ + Either `input_data_path` or `content` needs to be provided. + When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. + """ + model_config: CreateBatchCompletionsV1ModelConfig = Field(alias="model_config") + """ + Model configuration for the batch inference. Hardware configurations are inferred. + """ + + +class CreateBatchCompletionsV1Response(BaseModel): + job_id: str + + +class FilteredCompletionV2Request(CompletionV2Request): + model: Optional[str] = None # type: ignore[assignment] + stream: Optional[bool] = False + + +class FilteredChatCompletionV2Request(ChatCompletionV2Request): + model: Optional[str] = None # type: ignore[assignment] + stream: Optional[bool] = False + + +# V2 DTOs for batch completions +CompletionRequest: TypeAlias = Union[FilteredCompletionV2Request, FilteredChatCompletionV2Request] +CompletionResponse: TypeAlias = Union[CompletionV2Response, ChatCompletionV2Response] +CreateBatchCompletionsV2RequestContent: TypeAlias = Union[ + List[FilteredCompletionV2Request], List[FilteredChatCompletionV2Request] +] +CreateBatchCompletionsV2ModelConfig: TypeAlias = BatchCompletionsModelConfig + +BatchCompletionContent = Union[ + CreateBatchCompletionsV1RequestContent, CreateBatchCompletionsV2RequestContent +] + + +class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): + """ + Request object for batch completions. + """ + + content: Optional[BatchCompletionContent] = Field( + default=None, + description=""" +Either `input_data_path` or `content` needs to be provided. +When input_data_path is provided, the input file should be a JSON file of type List[CreateBatchCompletionsRequestContent]. +""", + ) + + model_config: BatchCompletionsModelConfig = Field( + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + +class BatchCompletionsJobStatus(str, Enum): + Queued = "queued" + Running = "running" + Completed = "completed" + Failed = "failed" + Cancelled = "cancelled" + Unknown = "unknown" + + +class BatchCompletionsJob(BaseModel): + job_id: str + input_data_path: Optional[str] = Field( + default=None, + description="Path to the input file. The input file should be a JSON file of type List[CreateBatchCompletionsRequestContent].", + ) + output_data_path: str = Field( + description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." + ) + + model_config: BatchCompletionsModelConfig = Field( + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + priority: Optional[str] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + status: BatchCompletionsJobStatus + created_at: str + expires_at: str + completed_at: Optional[str] + metadata: Optional[Dict[str, str]] + + +CreateBatchCompletionsV2Response: TypeAlias = BatchCompletionsJob + + +class UpdateBatchCompletionsV2Request(BaseModel): + job_id: str = Field(description="ID of the batch completions job") + priority: Optional[str] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + + +class UpdateBatchCompletionsV2Response(BatchCompletionsJob): + success: bool = Field(description="Whether the update was successful") + + +class CancelBatchCompletionsV2Request(BaseModel): + job_id: str = Field(description="ID of the batch completions job") + + +class CancelBatchCompletionsV2Response(BaseModel): + success: bool = Field(description="Whether the cancellation was successful") + + +class ListBatchCompletionV2Response(BaseModel): + jobs: List[BatchCompletionsJob] + + +class GetBatchCompletionV2Response(BaseModel): + job: BatchCompletionsJob diff --git a/clients/python/llmengine/data_types/chat_completion.py b/clients/python/llmengine/data_types/chat_completion.py new file mode 100644 index 00000000..251b2bd5 --- /dev/null +++ b/clients/python/llmengine/data_types/chat_completion.py @@ -0,0 +1,50 @@ +from typing import Optional, TypeAlias + +from .core import StreamError +from .gen.openai import ( + CreateChatCompletionRequest, + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +) +from .pydantic_types import BaseModel, Field +from .vllm import VLLMChatCompletionAdditionalParams + +# Fields that are a part of OpenAI spec but are not supported by model engine +UNSUPPORTED_FIELDS = ["service_tier"] + + +class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMChatCompletionAdditionalParams): + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + stream: Optional[bool] = Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + + top_k: Optional[int] = Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ) + + include_stop_str_in_output: Optional[bool] = Field( + None, description="Whether to include the stop strings in output text." + ) + + +ChatCompletionV2SyncResponse: TypeAlias = CreateChatCompletionResponse +ChatCompletionV2StreamSuccessChunk: TypeAlias = CreateChatCompletionStreamResponse + + +class ChatCompletionV2StreamErrorChunk(BaseModel): + error: StreamError + + +ChatCompletionV2Chunk: TypeAlias = ( + ChatCompletionV2StreamSuccessChunk | ChatCompletionV2StreamErrorChunk +) + +ChatCompletionV2Response: TypeAlias = ChatCompletionV2SyncResponse diff --git a/clients/python/llmengine/data_types/completion.py b/clients/python/llmengine/data_types/completion.py new file mode 100644 index 00000000..67384427 --- /dev/null +++ b/clients/python/llmengine/data_types/completion.py @@ -0,0 +1,305 @@ +from typing import Any, Dict, List, Optional, TypeAlias + +from .core import StreamError +from .gen.openai import CreateCompletionRequest, CreateCompletionResponse +from .pydantic_types import BaseModel, Field +from .vllm import VLLMCompletionAdditionalParams + +# Fields that are a part of OpenAI spec but are not supported by model engine +UNSUPPORTED_FIELDS = ["service_tier"] + + +class CompletionSyncV1Request(BaseModel): + """ + Request object for a synchronous prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ + + +class TokenOutput(BaseModel): + """ + Detailed token information. + """ + + token: str + """ + The token text. + """ + + log_prob: float + """ + The log probability of the token. + """ + + +class CompletionOutput(BaseModel): + """ + Represents the output of a completion request to a model. + """ + + text: str + """The text of the completion.""" + + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + # If we send request to api.spellbook.scale.com, we don't get this back. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + + num_completion_tokens: int + """Number of tokens in the completion.""" + + tokens: Optional[List[TokenOutput]] = None + """Detailed token information.""" + + +class CompletionSyncV1Response(BaseModel): + """ + Response object for a synchronous prompt completion. + """ + + request_id: str + """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs + associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as + follows: + + * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. + * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability + provider.""" + + output: CompletionOutput + """Completion output.""" + + +class CompletionStreamV1Request(BaseModel): + """ + Request object for a stream prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ + + +class CompletionStreamOutput(BaseModel): + text: str + """The text of the completion.""" + + finished: bool + """Whether the completion is finished.""" + + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + + num_completion_tokens: Optional[int] = None + """Number of tokens in the completion.""" + + token: Optional[TokenOutput] = None + """Detailed token information.""" + + +class CompletionStreamV1Response(BaseModel): + """Error of the response (if any).""" + + """ + Response object for a stream prompt completion task. + """ + + request_id: str + """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs + associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as + follows: + + * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. + * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability + provider.""" + + output: Optional[CompletionStreamOutput] = None + """Completion output.""" + + error: Optional[StreamError] = None + """Error of the response (if any).""" + + +class TokenUsage(BaseModel): + """ + Token usage for a prompt completion task. + """ + + num_prompt_tokens: Optional[int] = 0 + num_completion_tokens: Optional[int] = 0 + total_duration: Optional[float] = None + """Includes time spent waiting for the model to be ready.""" + + time_to_first_token: Optional[float] = None # Only for streaming requests + + @property + def num_total_tokens(self) -> int: + return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0) + + @property + def total_tokens_per_second(self) -> float: + return ( + self.num_total_tokens / self.total_duration + if self.total_duration and self.total_duration > 0 + else 0.0 + ) + + @property + def inter_token_latency(self) -> Optional[float]: # Only for streaming requests + # Note: we calculate a single inter-token latency for the entire request. + # Calculating latency between each token seems a bit heavyweight, although we can do this if we wanted + if ( + self.time_to_first_token is None + or self.num_completion_tokens is None + or self.total_duration is None + ): + return None + if self.num_completion_tokens < 2: + return None + return (self.total_duration - self.time_to_first_token) / (self.num_completion_tokens - 1) + + +class CompletionV2Request(CreateCompletionRequest, VLLMCompletionAdditionalParams): + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + stream: Optional[bool] = Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ) + + top_k: Optional[int] = Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ) + + include_stop_str_in_output: Optional[bool] = Field( + None, + description="Whether to include the stop strings in output text.", + ) + + +CompletionV2SyncResponse: TypeAlias = CreateCompletionResponse +CompletionV2StreamSuccessChunk: TypeAlias = CreateCompletionResponse + + +class CompletionV2StreamErrorChunk(BaseModel): + error: StreamError + + +CompletionV2StreamChunk: TypeAlias = CompletionV2StreamSuccessChunk | CompletionV2StreamErrorChunk + +CompletionV2Response: TypeAlias = CompletionV2SyncResponse diff --git a/clients/python/llmengine/data_types/core.py b/clients/python/llmengine/data_types/core.py new file mode 100644 index 00000000..93961a9e --- /dev/null +++ b/clients/python/llmengine/data_types/core.py @@ -0,0 +1,84 @@ +from enum import Enum +from typing import Literal, Union + +from .pydantic_types import BaseModel, Field + +CpuSpecificationType = Union[str, int, float] +StorageSpecificationType = Union[str, int, float] + + +class LLMInferenceFramework(str, Enum): + DEEPSPEED = "deepspeed" + TEXT_GENERATION_INFERENCE = "text_generation_inference" + VLLM = "vllm" + LIGHTLLM = "lightllm" + TENSORRT_LLM = "tensorrt_llm" + + +class LLMSource(str, Enum): + HUGGING_FACE = "hugging_face" + + +class Quantization(str, Enum): + BITSANDBYTES = "bitsandbytes" + AWQ = "awq" + + +class GpuType(str, Enum): + """Lists allowed GPU types for LLMEngine.""" + + NVIDIA_TESLA_T4 = "nvidia-tesla-t4" + NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" + NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" + NVIDIA_AMPERE_A100E = "nvidia-ampere-a100e" + NVIDIA_HOPPER_H100 = "nvidia-hopper-h100" + NVIDIA_HOPPER_H100_1G_20GB = "nvidia-hopper-h100-1g20gb" + NVIDIA_HOPPER_H100_3G_40GB = "nvidia-hopper-h100-3g40gb" + + +class ModelEndpointType(str, Enum): + STREAMING = "streaming" + + +class ModelEndpointStatus(str, Enum): + # Duplicates common/types::EndpointStatus, when refactor is done, delete the old type + # See EndpointStatus for status explanations + READY = "READY" + UPDATE_PENDING = "UPDATE_PENDING" + UPDATE_IN_PROGRESS = "UPDATE_IN_PROGRESS" + UPDATE_FAILED = "UPDATE_FAILED" + DELETE_IN_PROGRESS = "DELETE_IN_PROGRESS" + + +class CallbackBasicAuth(BaseModel): + kind: Literal["basic"] + username: str + password: str + + +class CallbackmTLSAuth(BaseModel): + kind: Literal["mtls"] + cert: str + key: str + + +class CallbackAuth(BaseModel): + __root__: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") + + +class StreamErrorContent(BaseModel): + error: str + """Error message.""" + timestamp: str + """Timestamp of the error.""" + + +class StreamError(BaseModel): + """ + Error object for a stream prompt completion task. + """ + + status_code: int + """The HTTP status code of the error.""" + content: StreamErrorContent + """The error content.""" diff --git a/server/llm_engine_server/__init__.py b/clients/python/llmengine/data_types/gen/__init__.py similarity index 100% rename from server/llm_engine_server/__init__.py rename to clients/python/llmengine/data_types/gen/__init__.py diff --git a/clients/python/llmengine/data_types/gen/openai.py b/clients/python/llmengine/data_types/gen/openai.py new file mode 100644 index 00000000..a97f0fd3 --- /dev/null +++ b/clients/python/llmengine/data_types/gen/openai.py @@ -0,0 +1,5832 @@ +# mypy: ignore-errors +# generated by datamodel-codegen: +# filename: openai-spec.yaml +# timestamp: 2024-08-22T02:56:18+00:00 + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Union + +import pydantic + +PYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.") +if PYDANTIC_V2: + from pydantic.v1 import AnyUrl, BaseModel, Extra, Field # noqa: F401 +else: + from pydantic import AnyUrl, BaseModel, Extra, Field # type: ignore # noqa: F401 +from typing_extensions import Annotated, Literal + + +class Error(BaseModel): + code: str + message: str + param: str + type: str + + +class ErrorResponse(BaseModel): + error: Error + + +class DeleteModelResponse(BaseModel): + id: str + deleted: bool + object: str + + +class Prompt(BaseModel): + __root__: Annotated[ + Optional[List[int]], + Field( + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + example="[1212, 318, 257, 1332, 13]", + min_items=1, + ), + ] = "<|endoftext|>" + + +class Prompt1Item(BaseModel): + __root__: Annotated[List[int], Field(min_items=1)] + + +class Prompt1(BaseModel): + __root__: Annotated[ + Optional[List[Prompt1Item]], + Field( + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + example="[[1212, 318, 257, 1332, 13]]", + min_items=1, + ), + ] = "<|endoftext|>" + + +class Stop(BaseModel): + __root__: Annotated[ + Optional[List[str]], + Field( + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + max_items=4, + min_items=1, + ), + ] = None + + +class Logprobs(BaseModel): + text_offset: Optional[List[int]] = None + token_logprobs: Optional[List[float]] = None + tokens: Optional[List[str]] = None + top_logprobs: Optional[List[Dict[str, float]]] = None + + +class Choice(BaseModel): + finish_reason: Annotated[ + Literal["stop", "length", "content_filter"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n" + ), + ] + index: int + logprobs: Logprobs + text: str + + +class ChatCompletionRequestMessageContentPartText(BaseModel): + type: Annotated[Literal["text"], Field(description="The type of the content part.")] + text: Annotated[str, Field(description="The text content.")] + + +class ImageUrl(BaseModel): + url: Annotated[ + AnyUrl, + Field(description="Either a URL of the image or the base64 encoded image data."), + ] + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding)." + ), + ] = "auto" + + +class ChatCompletionRequestMessageContentPartImage(BaseModel): + type: Annotated[Literal["image_url"], Field(description="The type of the content part.")] + image_url: ImageUrl + + +class ChatCompletionRequestMessageContentPartRefusal(BaseModel): + type: Annotated[Literal["refusal"], Field(description="The type of the content part.")] + refusal: Annotated[str, Field(description="The refusal message generated by the model.")] + + +class ChatCompletionRequestSystemMessageContentPart(BaseModel): + __root__: ChatCompletionRequestMessageContentPartText + + +class ChatCompletionRequestUserMessageContentPart(BaseModel): + __root__: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + + +class ChatCompletionRequestAssistantMessageContentPart(BaseModel): + __root__: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartRefusal, + ] + + +class ChatCompletionRequestToolMessageContentPart(BaseModel): + __root__: ChatCompletionRequestMessageContentPartText + + +class Content(BaseModel): + __root__: Annotated[ + List[ChatCompletionRequestSystemMessageContentPart], + Field( + description="An array of content parts with a defined type. For system messages, only type `text` is supported.", + min_items=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestSystemMessage(BaseModel): + content: Annotated[ + Union[str, Content], Field(description="The contents of the system message.") + ] + role: Annotated[ + Literal["system"], + Field(description="The role of the messages author, in this case `system`."), + ] + name: Annotated[ + Optional[str], + Field( + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." + ), + ] = None + + +class Content1(BaseModel): + __root__: Annotated[ + List[ChatCompletionRequestUserMessageContentPart], + Field( + description="An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model.", + min_items=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestUserMessage(BaseModel): + content: Annotated[ + Union[str, Content1], Field(description="The contents of the user message.\n") + ] + role: Annotated[ + Literal["user"], + Field(description="The role of the messages author, in this case `user`."), + ] + name: Annotated[ + Optional[str], + Field( + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." + ), + ] = None + + +class Content2(BaseModel): + __root__: Annotated[ + Optional[List[ChatCompletionRequestAssistantMessageContentPart]], + Field( + description="An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`.", + min_items=1, + title="Array of content parts", + ), + ] = None + + +class FunctionCall(BaseModel): + arguments: Annotated[ + str, + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] + name: Annotated[str, Field(description="The name of the function to call.")] + + +class Content3(BaseModel): + __root__: Annotated[ + List[ChatCompletionRequestToolMessageContentPart], + Field( + description="An array of content parts with a defined type. For tool messages, only type `text` is supported.", + min_items=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestToolMessage(BaseModel): + role: Annotated[ + Literal["tool"], + Field(description="The role of the messages author, in this case `tool`."), + ] + content: Annotated[Union[str, Content3], Field(description="The contents of the tool message.")] + tool_call_id: Annotated[str, Field(description="Tool call that this message is responding to.")] + + +class ChatCompletionRequestFunctionMessage(BaseModel): + role: Annotated[ + Literal["function"], + Field(description="The role of the messages author, in this case `function`."), + ] + content: Annotated[str, Field(description="The contents of the function message.")] + name: Annotated[str, Field(description="The name of the function to call.")] + + +class FunctionParameters(BaseModel): + pass + + class Config: + extra = Extra.allow + + +class ChatCompletionFunctions(BaseModel): + description: Annotated[ + Optional[str], + Field( + description="A description of what the function does, used by the model to choose when and how to call the function." + ), + ] = None + name: Annotated[ + str, + Field( + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + parameters: Optional[FunctionParameters] = None + + +class ChatCompletionFunctionCallOption(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class FunctionObject(BaseModel): + description: Annotated[ + Optional[str], + Field( + description="A description of what the function does, used by the model to choose when and how to call the function." + ), + ] = None + name: Annotated[ + str, + Field( + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + parameters: Optional[FunctionParameters] = None + strict: Annotated[ + Optional[bool], + Field( + description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling)." + ), + ] = False + + +class ResponseFormatText(BaseModel): + type: Annotated[ + Literal["text"], + Field(description="The type of response format being defined: `text`"), + ] + + +class ResponseFormatJsonObject(BaseModel): + type: Annotated[ + Literal["json_object"], + Field(description="The type of response format being defined: `json_object`"), + ] + + +class ResponseFormatJsonSchemaSchema(BaseModel): + pass + + class Config: + extra = Extra.allow + + +class JsonSchema(BaseModel): + description: Annotated[ + Optional[str], + Field( + description="A description of what the response format is for, used by the model to determine how to respond in the format." + ), + ] = None + name: Annotated[ + str, + Field( + description="The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + schema_: Annotated[Optional[ResponseFormatJsonSchemaSchema], Field(alias="schema")] = None + strict: Annotated[ + Optional[bool], + Field( + description="Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs)." + ), + ] = False + + +class ResponseFormatJsonSchema(BaseModel): + type: Annotated[ + Literal["json_schema"], + Field(description="The type of response format being defined: `json_schema`"), + ] + json_schema: JsonSchema + + +class Function(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class ChatCompletionNamedToolChoice(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: Function + + +class ParallelToolCalls(BaseModel): + __root__: Annotated[ + bool, + Field( + description="Whether to enable [parallel function calling](/docs/guides/function-calling/parallel-function-calling) during tool use." + ), + ] + + +class Function1(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + arguments: Annotated[ + str, + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] + + +class ChatCompletionMessageToolCall(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call.")] + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: Annotated[Function1, Field(description="The function that the model called.")] + + +class Function2(BaseModel): + name: Annotated[Optional[str], Field(description="The name of the function to call.")] = None + arguments: Annotated[ + Optional[str], + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] = None + + +class ChatCompletionMessageToolCallChunk(BaseModel): + index: int + id: Annotated[Optional[str], Field(description="The ID of the tool call.")] = None + type: Annotated[ + Optional[Literal["function"]], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] = None + function: Optional[Function2] = None + + +class ChatCompletionRole(BaseModel): + __root__: Annotated[ + Literal["system", "user", "assistant", "tool", "function"], + Field(description="The role of the author of a message"), + ] + + +class ChatCompletionStreamOptions(BaseModel): + include_usage: Annotated[ + Optional[bool], + Field( + description="If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.\n" + ), + ] = None + + +class FunctionCall2(BaseModel): + arguments: Annotated[ + Optional[str], + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] = None + name: Annotated[Optional[str], Field(description="The name of the function to call.")] = None + + +class ChatCompletionStreamResponseDelta(BaseModel): + content: Annotated[ + Optional[str], Field(description="The contents of the chunk message.") + ] = None + function_call: Annotated[ + Optional[FunctionCall2], + Field( + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + ), + ] = None + tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None + role: Annotated[ + Optional[Literal["system", "user", "assistant", "tool"]], + Field(description="The role of the author of this message."), + ] = None + refusal: Annotated[ + Optional[str], Field(description="The refusal message generated by the model.") + ] = None + + +class Stop1(BaseModel): + __root__: Annotated[ + List[str], + Field( + description="Up to 4 sequences where the API will stop generating further tokens.\n", + max_items=4, + min_items=1, + ), + ] + + +class TopLogprob(BaseModel): + token: Annotated[str, Field(description="The token.")] + logprob: Annotated[ + float, + Field( + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ] + bytes: Annotated[ + List[int], + Field( + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." + ), + ] + + +class ChatCompletionTokenLogprob(BaseModel): + token: Annotated[str, Field(description="The token.")] + logprob: Annotated[ + float, + Field( + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ] + bytes: Annotated[ + List[int], + Field( + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." + ), + ] + top_logprobs: Annotated[ + List[TopLogprob], + Field( + description="List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned." + ), + ] + + +class Logprobs2(BaseModel): + content: Annotated[ + List[ChatCompletionTokenLogprob], + Field(description="A list of message content tokens with log probability information."), + ] + refusal: Annotated[ + List[ChatCompletionTokenLogprob], + Field(description="A list of message refusal tokens with log probability information."), + ] + + +class Choice3(BaseModel): + delta: ChatCompletionStreamResponseDelta + logprobs: Annotated[ + Optional[Logprobs2], + Field(description="Log probability information for the choice."), + ] = None + finish_reason: Annotated[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + + +class Usage(BaseModel): + completion_tokens: Annotated[ + int, Field(description="Number of tokens in the generated completion.") + ] + prompt_tokens: Annotated[int, Field(description="Number of tokens in the prompt.")] + total_tokens: Annotated[ + int, + Field(description="Total number of tokens used in the request (prompt + completion)."), + ] + + +class CreateChatCompletionStreamResponse(BaseModel): + id: Annotated[ + str, + Field( + description="A unique identifier for the chat completion. Each chunk has the same ID." + ), + ] + choices: Annotated[ + List[Choice3], + Field( + description='A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the\nlast chunk if you set `stream_options: {"include_usage": true}`.\n' + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp." + ), + ] + model: Annotated[str, Field(description="The model to generate the completion.")] + service_tier: Annotated[ + Optional[Literal["scale", "default"]], + Field( + description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", + example="scale", + ), + ] = None + system_fingerprint: Annotated[ + Optional[str], + Field( + description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" + ), + ] = None + object: Annotated[ + Literal["chat.completion.chunk"], + Field(description="The object type, which is always `chat.completion.chunk`."), + ] + usage: Annotated[ + Optional[Usage], + Field( + description='An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.\nWhen present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.\n' + ), + ] = None + + +class CreateChatCompletionImageResponse(BaseModel): + pass + + +class CreateImageRequest(BaseModel): + prompt: Annotated[ + str, + Field( + description="A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`.", + example="A cute baby sea otter", + ), + ] + model: Annotated[ + Optional[Union[str, Literal["dall-e-2", "dall-e-3"]]], + Field(description="The model to use for image generation.", example="dall-e-3"), + ] = "dall-e-2" + n: Annotated[ + Optional[int], + Field( + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + example=1, + ge=1, + le=10, + ), + ] = 1 + quality: Annotated[ + Optional[Literal["standard", "hd"]], + Field( + description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", + example="standard", + ), + ] = "standard" + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + example="url", + ), + ] = "url" + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]], + Field( + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.", + example="1024x1024", + ), + ] = "1024x1024" + style: Annotated[ + Optional[Literal["vivid", "natural"]], + Field( + description="The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`.", + example="vivid", + ), + ] = "vivid" + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ), + ] = None + + +class Image(BaseModel): + b64_json: Annotated[ + Optional[str], + Field( + description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`." + ), + ] = None + url: Annotated[ + Optional[str], + Field( + description="The URL of the generated image, if `response_format` is `url` (default)." + ), + ] = None + revised_prompt: Annotated[ + Optional[str], + Field( + description="The prompt that was used to generate the image, if there was any revision to the prompt." + ), + ] = None + + +class CreateImageEditRequest(BaseModel): + image: Annotated[ + bytes, + Field( + description="The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask." + ), + ] + prompt: Annotated[ + str, + Field( + description="A text description of the desired image(s). The maximum length is 1000 characters.", + example="A cute baby sea otter wearing a beret", + ), + ] + mask: Annotated[ + Optional[bytes], + Field( + description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`." + ), + ] = None + model: Annotated[ + Optional[Union[str, Literal["dall-e-2"]]], + Field( + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + example="dall-e-2", + ), + ] = "dall-e-2" + n: Annotated[ + Optional[int], + Field( + description="The number of images to generate. Must be between 1 and 10.", + example=1, + ge=1, + le=10, + ), + ] = 1 + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024"]], + Field( + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + example="1024x1024", + ), + ] = "1024x1024" + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + example="url", + ), + ] = "url" + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ), + ] = None + + +class CreateImageVariationRequest(BaseModel): + image: Annotated[ + bytes, + Field( + description="The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square." + ), + ] + model: Annotated[ + Optional[Union[str, Literal["dall-e-2"]]], + Field( + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + example="dall-e-2", + ), + ] = "dall-e-2" + n: Annotated[ + Optional[int], + Field( + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + example=1, + ge=1, + le=10, + ), + ] = 1 + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + example="url", + ), + ] = "url" + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024"]], + Field( + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + example="1024x1024", + ), + ] = "1024x1024" + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ), + ] = None + + +class CreateModerationRequest(BaseModel): + input: Annotated[Union[str, List[str]], Field(description="The input text to classify")] + model: Annotated[ + Optional[Union[str, Literal["text-moderation-latest", "text-moderation-stable"]]], + Field( + description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", + example="text-moderation-stable", + ), + ] = "text-moderation-latest" + + +class Categories(BaseModel): + hate: Annotated[ + bool, + Field( + description="Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harassment." + ), + ] + hate_threatening: Annotated[ + bool, + Field( + alias="hate/threatening", + description="Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste.", + ), + ] + harassment: Annotated[ + bool, + Field( + description="Content that expresses, incites, or promotes harassing language towards any target." + ), + ] + harassment_threatening: Annotated[ + bool, + Field( + alias="harassment/threatening", + description="Harassment content that also includes violence or serious harm towards any target.", + ), + ] + self_harm: Annotated[ + bool, + Field( + alias="self-harm", + description="Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders.", + ), + ] + self_harm_intent: Annotated[ + bool, + Field( + alias="self-harm/intent", + description="Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders.", + ), + ] + self_harm_instructions: Annotated[ + bool, + Field( + alias="self-harm/instructions", + description="Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts.", + ), + ] + sexual: Annotated[ + bool, + Field( + description="Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness)." + ), + ] + sexual_minors: Annotated[ + bool, + Field( + alias="sexual/minors", + description="Sexual content that includes an individual who is under 18 years old.", + ), + ] + violence: Annotated[ + bool, + Field(description="Content that depicts death, violence, or physical injury."), + ] + violence_graphic: Annotated[ + bool, + Field( + alias="violence/graphic", + description="Content that depicts death, violence, or physical injury in graphic detail.", + ), + ] + + +class CategoryScores(BaseModel): + hate: Annotated[float, Field(description="The score for the category 'hate'.")] + hate_threatening: Annotated[ + float, + Field( + alias="hate/threatening", + description="The score for the category 'hate/threatening'.", + ), + ] + harassment: Annotated[float, Field(description="The score for the category 'harassment'.")] + harassment_threatening: Annotated[ + float, + Field( + alias="harassment/threatening", + description="The score for the category 'harassment/threatening'.", + ), + ] + self_harm: Annotated[ + float, + Field(alias="self-harm", description="The score for the category 'self-harm'."), + ] + self_harm_intent: Annotated[ + float, + Field( + alias="self-harm/intent", + description="The score for the category 'self-harm/intent'.", + ), + ] + self_harm_instructions: Annotated[ + float, + Field( + alias="self-harm/instructions", + description="The score for the category 'self-harm/instructions'.", + ), + ] + sexual: Annotated[float, Field(description="The score for the category 'sexual'.")] + sexual_minors: Annotated[ + float, + Field( + alias="sexual/minors", + description="The score for the category 'sexual/minors'.", + ), + ] + violence: Annotated[float, Field(description="The score for the category 'violence'.")] + violence_graphic: Annotated[ + float, + Field( + alias="violence/graphic", + description="The score for the category 'violence/graphic'.", + ), + ] + + +class Result(BaseModel): + flagged: Annotated[bool, Field(description="Whether any of the below categories are flagged.")] + categories: Annotated[ + Categories, + Field(description="A list of the categories, and whether they are flagged or not."), + ] + category_scores: Annotated[ + CategoryScores, + Field( + description="A list of the categories along with their scores as predicted by model." + ), + ] + + +class CreateModerationResponse(BaseModel): + id: Annotated[str, Field(description="The unique identifier for the moderation request.")] + model: Annotated[str, Field(description="The model used to generate the moderation results.")] + results: Annotated[List[Result], Field(description="A list of moderation objects.")] + + +class CreateFileRequest(BaseModel): + class Config: + extra = Extra.forbid + + file: Annotated[bytes, Field(description="The File object (not file name) to be uploaded.\n")] + purpose: Annotated[ + Literal["assistants", "batch", "fine-tune", "vision"], + Field( + description='The intended purpose of the uploaded file.\n\nUse "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning).\n' + ), + ] + + +class DeleteFileResponse(BaseModel): + id: str + object: Literal["file"] + deleted: bool + + +class CreateUploadRequest(BaseModel): + class Config: + extra = Extra.forbid + + filename: Annotated[str, Field(description="The name of the file to upload.\n")] + purpose: Annotated[ + Literal["assistants", "batch", "fine-tune", "vision"], + Field( + description="The intended purpose of the uploaded file.\n\nSee the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose).\n" + ), + ] + bytes: Annotated[int, Field(description="The number of bytes in the file you are uploading.\n")] + mime_type: Annotated[ + str, + Field( + description="The MIME type of the file.\n\nThis must fall within the supported MIME types for your file purpose. See the supported MIME types for assistants and vision.\n" + ), + ] + + +class AddUploadPartRequest(BaseModel): + class Config: + extra = Extra.forbid + + data: Annotated[bytes, Field(description="The chunk of bytes for this Part.\n")] + + +class CompleteUploadRequest(BaseModel): + class Config: + extra = Extra.forbid + + part_ids: Annotated[List[str], Field(description="The ordered list of Part IDs.\n")] + md5: Annotated[ + Optional[str], + Field( + description="The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect.\n" + ), + ] = None + + +class CancelUploadRequest(BaseModel): + pass + + class Config: + extra = Extra.forbid + + +class BatchSize(BaseModel): + __root__: Annotated[ + int, + Field( + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + ge=1, + le=256, + ), + ] + + +class LearningRateMultiplier(BaseModel): + __root__: Annotated[ + float, + Field( + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + gt=0.0, + ), + ] + + +class NEpochs(BaseModel): + __root__: Annotated[ + int, + Field( + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", + ge=1, + le=50, + ), + ] + + +class Hyperparameters(BaseModel): + batch_size: Annotated[ + Optional[Union[Literal["auto"], BatchSize]], + Field( + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n" + ), + ] = "auto" + learning_rate_multiplier: Annotated[ + Optional[Union[Literal["auto"], LearningRateMultiplier]], + Field( + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n" + ), + ] = "auto" + n_epochs: Annotated[ + Optional[Union[Literal["auto"], NEpochs]], + Field( + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n" + ), + ] = "auto" + + +class Wandb(BaseModel): + project: Annotated[ + str, + Field( + description="The name of the project that the new run will be created under.\n", + example="my-wandb-project", + ), + ] + name: Annotated[ + Optional[str], + Field( + description="A display name to set for the run. If not set, we will use the Job ID as the name.\n" + ), + ] = None + entity: Annotated[ + Optional[str], + Field( + description="The entity to use for the run. This allows you to set the team or username of the WandB user that you would\nlike associated with the run. If not set, the default entity for the registered WandB API key is used.\n" + ), + ] = None + tags: Annotated[ + Optional[List[str]], + Field( + description='A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some\ndefault tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".\n' + ), + ] = None + + +class Integration(BaseModel): + type: Annotated[ + Literal["wandb"], + Field( + description='The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.\n' + ), + ] + wandb: Annotated[ + Wandb, + Field( + description="The settings for your integration with Weights and Biases. This payload specifies the project that\nmetrics will be sent to. Optionally, you can set an explicit display name for your run, add tags\nto your run, and set a default entity (team, username, etc) to be associated with your run.\n" + ), + ] + + +class CreateFineTuningJobRequest(BaseModel): + model: Annotated[ + Union[str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo", "gpt-4o-mini"]], + Field( + description="The name of the model to fine-tune. You can select one of the\n[supported models](/docs/guides/fine-tuning/which-models-can-be-fine-tuned).\n", + example="gpt-4o-mini", + ), + ] + training_file: Annotated[ + str, + Field( + description="The ID of an uploaded file that contains training data.\n\nSee [upload file](/docs/api-reference/files/create) for how to upload a file.\n\nYour dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.\n\nThe contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + example="file-abc123", + ), + ] + hyperparameters: Annotated[ + Optional[Hyperparameters], + Field(description="The hyperparameters used for the fine-tuning job."), + ] = None + suffix: Annotated[ + Optional[str], + Field( + description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`.\n', + max_length=40, + min_length=1, + ), + ] = None + validation_file: Annotated[ + Optional[str], + Field( + description="The ID of an uploaded file that contains validation data.\n\nIf you provide this file, the data is used to generate validation\nmetrics periodically during fine-tuning. These metrics can be viewed in\nthe fine-tuning results file.\nThe same data should not be present in both train and validation files.\n\nYour dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + example="file-abc123", + ), + ] = None + integrations: Annotated[ + Optional[List[Integration]], + Field(description="A list of integrations to enable for your fine-tuning job."), + ] = None + seed: Annotated[ + Optional[int], + Field( + description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases.\nIf a seed is not specified, one will be generated for you.\n", + example=42, + ge=0, + le=2147483647, + ), + ] = None + + +class Input(BaseModel): + __root__: Annotated[ + List[str], + Field( + description="The array of strings that will be turned into an embedding.", + example="The quick brown fox jumped over the lazy dog", + max_items=2048, + min_items=1, + title="array", + ), + ] + + +class Input1(BaseModel): + __root__: Annotated[ + List[int], + Field( + description="The array of integers that will be turned into an embedding.", + example="[1212, 318, 257, 1332, 13]", + max_items=2048, + min_items=1, + title="array", + ), + ] + + +class Input2Item(BaseModel): + __root__: Annotated[List[int], Field(min_items=1)] + + +class Input2(BaseModel): + __root__: Annotated[ + List[Input2Item], + Field( + description="The array of arrays containing integers that will be turned into an embedding.", + example="[[1212, 318, 257, 1332, 13]]", + max_items=2048, + min_items=1, + title="array", + ), + ] + + +class CreateEmbeddingRequest(BaseModel): + class Config: + extra = Extra.forbid + + input: Annotated[ + Union[str, Input, Input1, Input2], + Field( + description="Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + example="The quick brown fox jumped over the lazy dog", + ), + ] + model: Annotated[ + Union[ + str, + Literal[ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ], + ], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + example="text-embedding-3-small", + ), + ] + encoding_format: Annotated[ + Optional[Literal["float", "base64"]], + Field( + description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", + example="float", + ), + ] = "float" + dimensions: Annotated[ + Optional[int], + Field( + description="The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.\n", + ge=1, + ), + ] = None + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ), + ] = None + + +class Usage1(BaseModel): + prompt_tokens: Annotated[int, Field(description="The number of tokens used by the prompt.")] + total_tokens: Annotated[ + int, Field(description="The total number of tokens used by the request.") + ] + + +class CreateTranscriptionRequest(BaseModel): + class Config: + extra = Extra.forbid + + file: Annotated[ + bytes, + Field( + description="The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n" + ), + ] + model: Annotated[ + Union[str, Literal["whisper-1"]], + Field( + description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", + example="whisper-1", + ), + ] + language: Annotated[ + Optional[str], + Field( + description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n" + ), + ] = None + prompt: Annotated[ + Optional[str], + Field( + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n" + ), + ] = None + response_format: Annotated[ + Optional[Literal["json", "text", "srt", "verbose_json", "vtt"]], + Field( + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n" + ), + ] = "json" + temperature: Annotated[ + Optional[float], + Field( + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n" + ), + ] = 0 + timestamp_granularities__: Annotated[ + Optional[List[Literal["word", "segment"]]], + Field( + alias="timestamp_granularities[]", + description="The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency.\n", + ), + ] = ["segment"] + + +class CreateTranscriptionResponseJson(BaseModel): + text: Annotated[str, Field(description="The transcribed text.")] + + +class TranscriptionSegment(BaseModel): + id: Annotated[int, Field(description="Unique identifier of the segment.")] + seek: Annotated[int, Field(description="Seek offset of the segment.")] + start: Annotated[float, Field(description="Start time of the segment in seconds.")] + end: Annotated[float, Field(description="End time of the segment in seconds.")] + text: Annotated[str, Field(description="Text content of the segment.")] + tokens: Annotated[List[int], Field(description="Array of token IDs for the text content.")] + temperature: Annotated[ + float, + Field(description="Temperature parameter used for generating the segment."), + ] + avg_logprob: Annotated[ + float, + Field( + description="Average logprob of the segment. If the value is lower than -1, consider the logprobs failed." + ), + ] + compression_ratio: Annotated[ + float, + Field( + description="Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed." + ), + ] + no_speech_prob: Annotated[ + float, + Field( + description="Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent." + ), + ] + + +class TranscriptionWord(BaseModel): + word: Annotated[str, Field(description="The text content of the word.")] + start: Annotated[float, Field(description="Start time of the word in seconds.")] + end: Annotated[float, Field(description="End time of the word in seconds.")] + + +class CreateTranscriptionResponseVerboseJson(BaseModel): + language: Annotated[str, Field(description="The language of the input audio.")] + duration: Annotated[str, Field(description="The duration of the input audio.")] + text: Annotated[str, Field(description="The transcribed text.")] + words: Annotated[ + Optional[List[TranscriptionWord]], + Field(description="Extracted words and their corresponding timestamps."), + ] = None + segments: Annotated[ + Optional[List[TranscriptionSegment]], + Field(description="Segments of the transcribed text and their corresponding details."), + ] = None + + +class CreateTranslationRequest(BaseModel): + class Config: + extra = Extra.forbid + + file: Annotated[ + bytes, + Field( + description="The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n" + ), + ] + model: Annotated[ + Union[str, Literal["whisper-1"]], + Field( + description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", + example="whisper-1", + ), + ] + prompt: Annotated[ + Optional[str], + Field( + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n" + ), + ] = None + response_format: Annotated[ + Optional[str], + Field( + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n" + ), + ] = "json" + temperature: Annotated[ + Optional[float], + Field( + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n" + ), + ] = 0 + + +class CreateTranslationResponseJson(BaseModel): + text: str + + +class CreateTranslationResponseVerboseJson(BaseModel): + language: Annotated[ + str, + Field(description="The language of the output translation (always `english`)."), + ] + duration: Annotated[str, Field(description="The duration of the input audio.")] + text: Annotated[str, Field(description="The translated text.")] + segments: Annotated[ + Optional[List[TranscriptionSegment]], + Field(description="Segments of the translated text and their corresponding details."), + ] = None + + +class CreateSpeechRequest(BaseModel): + class Config: + extra = Extra.forbid + + model: Annotated[ + Union[str, Literal["tts-1", "tts-1-hd"]], + Field( + description="One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd`\n" + ), + ] + input: Annotated[ + str, + Field( + description="The text to generate audio for. The maximum length is 4096 characters.", + max_length=4096, + ), + ] + voice: Annotated[ + Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], + Field( + description="The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options)." + ), + ] + response_format: Annotated[ + Optional[Literal["mp3", "opus", "aac", "flac", "wav", "pcm"]], + Field( + description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`." + ), + ] = "mp3" + speed: Annotated[ + Optional[float], + Field( + description="The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.", + ge=0.25, + le=4.0, + ), + ] = 1.0 + + +class Model(BaseModel): + id: Annotated[ + str, + Field(description="The model identifier, which can be referenced in the API endpoints."), + ] + created: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) when the model was created."), + ] + object: Annotated[ + Literal["model"], Field(description='The object type, which is always "model".') + ] + owned_by: Annotated[str, Field(description="The organization that owns the model.")] + + +class OpenAIFile(BaseModel): + id: Annotated[ + str, + Field(description="The file identifier, which can be referenced in the API endpoints."), + ] + bytes: Annotated[int, Field(description="The size of the file, in bytes.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the file was created."), + ] + filename: Annotated[str, Field(description="The name of the file.")] + object: Annotated[ + Literal["file"], Field(description="The object type, which is always `file`.") + ] + purpose: Annotated[ + Literal[ + "assistants", + "assistants_output", + "batch", + "batch_output", + "fine-tune", + "fine-tune-results", + "vision", + ], + Field( + description="The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`." + ), + ] + status: Annotated[ + Literal["uploaded", "processed", "error"], + Field( + description="Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`." + ), + ] + status_details: Annotated[ + Optional[str], + Field( + description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`." + ), + ] = None + + +class Upload(BaseModel): + id: Annotated[ + str, + Field( + description="The Upload unique identifier, which can be referenced in API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Upload was created."), + ] + filename: Annotated[str, Field(description="The name of the file to be uploaded.")] + bytes: Annotated[int, Field(description="The intended number of bytes to be uploaded.")] + purpose: Annotated[ + str, + Field( + description="The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values." + ), + ] + status: Annotated[ + Literal["pending", "completed", "cancelled", "expired"], + Field(description="The status of the Upload."), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Upload was created."), + ] + object: Annotated[ + Optional[Literal["upload"]], + Field(description='The object type, which is always "upload".'), + ] = None + file: Annotated[ + Optional[OpenAIFile], + Field(description="The ready File object after the Upload is completed."), + ] = None + + +class UploadPart(BaseModel): + id: Annotated[ + str, + Field( + description="The upload Part unique identifier, which can be referenced in API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Part was created."), + ] + upload_id: Annotated[ + str, + Field(description="The ID of the Upload object that this Part was added to."), + ] + object: Annotated[ + Literal["upload.part"], + Field(description="The object type, which is always `upload.part`."), + ] + + +class Embedding(BaseModel): + index: Annotated[ + int, Field(description="The index of the embedding in the list of embeddings.") + ] + embedding: Annotated[ + List[float], + Field( + description="The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).\n" + ), + ] + object: Annotated[ + Literal["embedding"], + Field(description='The object type, which is always "embedding".'), + ] + + +class Error1(BaseModel): + code: Annotated[str, Field(description="A machine-readable error code.")] + message: Annotated[str, Field(description="A human-readable error message.")] + param: Annotated[ + str, + Field( + description="The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific." + ), + ] + + +class NEpochs1(BaseModel): + __root__: Annotated[ + int, + Field( + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.', + ge=1, + le=50, + ), + ] + + +class Hyperparameters1(BaseModel): + n_epochs: Annotated[ + Union[Literal["auto"], NEpochs1], + Field( + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.' + ), + ] + + +class FineTuningIntegration(BaseModel): + type: Annotated[ + Literal["wandb"], + Field(description="The type of the integration being enabled for the fine-tuning job"), + ] + wandb: Annotated[ + Wandb, + Field( + description="The settings for your integration with Weights and Biases. This payload specifies the project that\nmetrics will be sent to. Optionally, you can set an explicit display name for your run, add tags\nto your run, and set a default entity (team, username, etc) to be associated with your run.\n" + ), + ] + + +class FineTuningJobEvent(BaseModel): + id: str + created_at: int + level: Literal["info", "warn", "error"] + message: str + object: Literal["fine_tuning.job.event"] + + +class Metrics(BaseModel): + step: Optional[float] = None + train_loss: Optional[float] = None + train_mean_token_accuracy: Optional[float] = None + valid_loss: Optional[float] = None + valid_mean_token_accuracy: Optional[float] = None + full_valid_loss: Optional[float] = None + full_valid_mean_token_accuracy: Optional[float] = None + + +class FineTuningJobCheckpoint(BaseModel): + id: Annotated[ + str, + Field( + description="The checkpoint identifier, which can be referenced in the API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the checkpoint was created."), + ] + fine_tuned_model_checkpoint: Annotated[ + str, + Field(description="The name of the fine-tuned checkpoint model that is created."), + ] + step_number: Annotated[ + int, Field(description="The step number that the checkpoint was created at.") + ] + metrics: Annotated[ + Metrics, + Field(description="Metrics at the step number during the fine-tuning job."), + ] + fine_tuning_job_id: Annotated[ + str, + Field(description="The name of the fine-tuning job that this checkpoint was created from."), + ] + object: Annotated[ + Literal["fine_tuning.job.checkpoint"], + Field(description='The object type, which is always "fine_tuning.job.checkpoint".'), + ] + + +class FinetuneCompletionRequestInput(BaseModel): + prompt: Annotated[ + Optional[str], Field(description="The input prompt for this training example.") + ] = None + completion: Annotated[ + Optional[str], + Field(description="The desired completion for this training example."), + ] = None + + +class CompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, Field(description="Number of tokens in the generated completion.") + ] + prompt_tokens: Annotated[int, Field(description="Number of tokens in the prompt.")] + total_tokens: Annotated[ + int, + Field(description="Total number of tokens used in the request (prompt + completion)."), + ] + + +class RunCompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, + Field(description="Number of completion tokens used over the course of the run."), + ] + prompt_tokens: Annotated[ + int, + Field(description="Number of prompt tokens used over the course of the run."), + ] + total_tokens: Annotated[ + int, Field(description="Total number of tokens used (prompt + completion).") + ] + + +class RunStepCompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, + Field(description="Number of completion tokens used over the course of the run step."), + ] + prompt_tokens: Annotated[ + int, + Field(description="Number of prompt tokens used over the course of the run step."), + ] + total_tokens: Annotated[ + int, Field(description="Total number of tokens used (prompt + completion).") + ] + + +class AssistantsApiResponseFormatOption(BaseModel): + __root__: Annotated[ + Union[ + Literal["auto"], + ResponseFormatText, + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ], + Field( + description='Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' + ), + ] + + +class CodeInterpreter(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool.\n", + max_items=20, + ), + ] = [] + + +class FileSearch(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_items=1, + ), + ] = None + + +class ToolResources(BaseModel): + code_interpreter: Optional[CodeInterpreter] = None + file_search: Optional[FileSearch] = None + + +class CodeInterpreter1(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_items=20, + ), + ] = [] + + +class ChunkingStrategy(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class Static(BaseModel): + class Config: + extra = Extra.forbid + + max_chunk_size_tokens: Annotated[ + int, + Field( + description="The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`.", + ge=100, + le=4096, + ), + ] + chunk_overlap_tokens: Annotated[ + int, + Field( + description="The number of tokens that overlap between chunks. The default value is `400`.\n\nNote that the overlap must not exceed half of `max_chunk_size_tokens`.\n" + ), + ] + + +class ChunkingStrategy1(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_items=10000, + ), + ] = None + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy, ChunkingStrategy1]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class FileSearch1(BaseModel): + vector_store_ids: Annotated[ + List[str], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_items=1, + ), + ] + vector_stores: Annotated[ + Optional[List[VectorStore]], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_items=1, + ), + ] = None + + +class ChunkingStrategy2(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy3(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore1(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_items=10000, + ), + ] = None + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy2, ChunkingStrategy3]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class FileSearch2(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_items=1, + ), + ] = None + vector_stores: Annotated[ + List[VectorStore1], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_items=1, + ), + ] + + +class ToolResources1(BaseModel): + code_interpreter: Optional[CodeInterpreter1] = None + file_search: Optional[Union[FileSearch1, FileSearch2]] = None + + +class CodeInterpreter2(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_items=20, + ), + ] = [] + + +class FileSearch3(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_items=1, + ), + ] = None + + +class ToolResources2(BaseModel): + code_interpreter: Optional[CodeInterpreter2] = None + file_search: Optional[FileSearch3] = None + + +class DeleteAssistantResponse(BaseModel): + id: str + deleted: bool + object: Literal["assistant.deleted"] + + +class AssistantToolsCode(BaseModel): + type: Annotated[ + Literal["code_interpreter"], + Field(description="The type of tool being defined: `code_interpreter`"), + ] + + +class FileSearch4(BaseModel): + max_num_results: Annotated[ + Optional[int], + Field( + description="The maximum number of results the file search tool should output. The default is 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between 1 and 50 inclusive.\n\nNote that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information.\n", + ge=1, + le=50, + ), + ] = None + + +class AssistantToolsFileSearch(BaseModel): + type: Annotated[ + Literal["file_search"], + Field(description="The type of tool being defined: `file_search`"), + ] + file_search: Annotated[ + Optional[FileSearch4], Field(description="Overrides for the file search tool.") + ] = None + + +class AssistantToolsFileSearchTypeOnly(BaseModel): + type: Annotated[ + Literal["file_search"], + Field(description="The type of tool being defined: `file_search`"), + ] + + +class AssistantToolsFunction(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of tool being defined: `function`"), + ] + function: FunctionObject + + +class TruncationObject(BaseModel): + type: Annotated[ + Literal["auto", "last_messages"], + Field( + description="The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`." + ), + ] + last_messages: Annotated[ + Optional[int], + Field( + description="The number of most recent messages from the thread when constructing the context for the run.", + ge=1, + ), + ] = None + + +class Function3(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class AssistantsNamedToolChoice(BaseModel): + type: Annotated[ + Literal["function", "code_interpreter", "file_search"], + Field( + description="The type of the tool. If type is `function`, the function name must be set" + ), + ] + function: Optional[Function3] = None + + +class LastError(BaseModel): + code: Annotated[ + Literal["server_error", "rate_limit_exceeded", "invalid_prompt"], + Field(description="One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class IncompleteDetails(BaseModel): + reason: Annotated[ + Optional[Literal["max_completion_tokens", "max_prompt_tokens"]], + Field( + description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run." + ), + ] = None + + +class ModifyRunRequest(BaseModel): + class Config: + extra = Extra.forbid + + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class ToolOutput(BaseModel): + tool_call_id: Annotated[ + Optional[str], + Field( + description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for." + ), + ] = None + output: Annotated[ + Optional[str], + Field(description="The output of the tool call to be submitted to continue the run."), + ] = None + + +class SubmitToolOutputsRunRequest(BaseModel): + class Config: + extra = Extra.forbid + + tool_outputs: Annotated[ + List[ToolOutput], + Field(description="A list of tools for which the outputs are being submitted."), + ] + stream: Annotated[ + Optional[bool], + Field( + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" + ), + ] = None + + +class Function4(BaseModel): + name: Annotated[str, Field(description="The name of the function.")] + arguments: Annotated[ + str, + Field(description="The arguments that the model expects you to pass to the function."), + ] + + +class RunToolCallObject(BaseModel): + id: Annotated[ + str, + Field( + description="The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint." + ), + ] + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call the output is required for. For now, this is always `function`." + ), + ] + function: Annotated[Function4, Field(description="The function definition.")] + + +class CodeInterpreter3(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_items=20, + ), + ] = [] + + +class FileSearch5(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_items=1, + ), + ] = None + + +class ToolResources3(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch5] = None + + +class FileSearch6(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_items=1, + ), + ] = None + + +class ToolResources4(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch6] = None + + +class ThreadObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread"], + Field(description="The object type, which is always `thread`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the thread was created."), + ] + tool_resources: Annotated[ + ToolResources4, + Field( + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class ChunkingStrategy4(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy5(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore2(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_items=10000, + ), + ] = None + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy4, ChunkingStrategy5]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class FileSearch7(BaseModel): + vector_store_ids: Annotated[ + List[str], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_items=1, + ), + ] + vector_stores: Annotated[ + Optional[List[VectorStore2]], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_items=1, + ), + ] = None + + +class ChunkingStrategy6(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy7(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore3(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_items=10000, + ), + ] = None + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy6, ChunkingStrategy7]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class FileSearch8(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_items=1, + ), + ] = None + vector_stores: Annotated[ + List[VectorStore3], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_items=1, + ), + ] + + +class ToolResources5(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[Union[FileSearch7, FileSearch8]] = None + + +class FileSearch9(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_items=1, + ), + ] = None + + +class ToolResources6(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch9] = None + + +class ModifyThreadRequest(BaseModel): + class Config: + extra = Extra.forbid + + tool_resources: Annotated[ + Optional[ToolResources6], + Field( + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class DeleteThreadResponse(BaseModel): + id: str + deleted: bool + object: Literal["thread.deleted"] + + +class ListThreadsResponse(BaseModel): + object: Annotated[str, Field(example="list")] + data: List[ThreadObject] + first_id: Annotated[str, Field(example="asst_abc123")] + last_id: Annotated[str, Field(example="asst_abc456")] + has_more: Annotated[bool, Field(example=False)] + + +class IncompleteDetails1(BaseModel): + reason: Annotated[ + Literal["content_filter", "max_tokens", "run_cancelled", "run_expired", "run_failed"], + Field(description="The reason the message is incomplete."), + ] + + +class Attachment(BaseModel): + file_id: Annotated[ + Optional[str], Field(description="The ID of the file to attach to the message.") + ] = None + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearchTypeOnly]]], + Field(description="The tools to add this file to."), + ] = None + + +class ModifyMessageRequest(BaseModel): + class Config: + extra = Extra.forbid + + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class DeleteMessageResponse(BaseModel): + id: str + deleted: bool + object: Literal["thread.message.deleted"] + + +class ImageFile(BaseModel): + file_id: Annotated[ + str, + Field( + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.' + ), + ] + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`." + ), + ] = "auto" + + +class MessageContentImageFileObject(BaseModel): + type: Annotated[Literal["image_file"], Field(description="Always `image_file`.")] + image_file: ImageFile + + +class ImageFile1(BaseModel): + file_id: Annotated[ + Optional[str], + Field( + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.' + ), + ] = None + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`." + ), + ] = "auto" + + +class MessageDeltaContentImageFileObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["image_file"], Field(description="Always `image_file`.")] + image_file: Optional[ImageFile1] = None + + +class ImageUrl1(BaseModel): + url: Annotated[ + AnyUrl, + Field( + description="The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + ), + ] + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`" + ), + ] = "auto" + + +class MessageContentImageUrlObject(BaseModel): + type: Annotated[Literal["image_url"], Field(description="The type of the content part.")] + image_url: ImageUrl1 + + +class ImageUrl2(BaseModel): + url: Annotated[ + Optional[str], + Field( + description="The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + ), + ] = None + detail: Annotated[ + Optional[Literal["auto", "low", "high"]], + Field( + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`." + ), + ] = "auto" + + +class MessageDeltaContentImageUrlObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["image_url"], Field(description="Always `image_url`.")] + image_url: Optional[ImageUrl2] = None + + +class MessageContentRefusalObject(BaseModel): + type: Annotated[Literal["refusal"], Field(description="Always `refusal`.")] + refusal: str + + +class MessageRequestContentTextObject(BaseModel): + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Annotated[str, Field(description="Text content to be sent to the model")] + + +class FileCitation(BaseModel): + file_id: Annotated[str, Field(description="The ID of the specific File the citation is from.")] + + +class MessageContentTextAnnotationsFileCitationObject(BaseModel): + type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] + text: Annotated[ + str, + Field(description="The text in the message content that needs to be replaced."), + ] + file_citation: FileCitation + start_index: Annotated[int, Field(ge=0)] + end_index: Annotated[int, Field(ge=0)] + + +class FilePath(BaseModel): + file_id: Annotated[str, Field(description="The ID of the file that was generated.")] + + +class MessageContentTextAnnotationsFilePathObject(BaseModel): + type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] + text: Annotated[ + str, + Field(description="The text in the message content that needs to be replaced."), + ] + file_path: FilePath + start_index: Annotated[int, Field(ge=0)] + end_index: Annotated[int, Field(ge=0)] + + +class MessageDeltaContentRefusalObject(BaseModel): + index: Annotated[int, Field(description="The index of the refusal part in the message.")] + type: Annotated[Literal["refusal"], Field(description="Always `refusal`.")] + refusal: Optional[str] = None + + +class FileCitation1(BaseModel): + file_id: Annotated[ + Optional[str], + Field(description="The ID of the specific File the citation is from."), + ] = None + quote: Annotated[Optional[str], Field(description="The specific quote in the file.")] = None + + +class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): + index: Annotated[ + int, Field(description="The index of the annotation in the text content part.") + ] + type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] + text: Annotated[ + Optional[str], + Field(description="The text in the message content that needs to be replaced."), + ] = None + file_citation: Optional[FileCitation1] = None + start_index: Annotated[Optional[int], Field(ge=0)] = None + end_index: Annotated[Optional[int], Field(ge=0)] = None + + +class FilePath1(BaseModel): + file_id: Annotated[ + Optional[str], Field(description="The ID of the file that was generated.") + ] = None + + +class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): + index: Annotated[ + int, Field(description="The index of the annotation in the text content part.") + ] + type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] + text: Annotated[ + Optional[str], + Field(description="The text in the message content that needs to be replaced."), + ] = None + file_path: Optional[FilePath1] = None + start_index: Annotated[Optional[int], Field(ge=0)] = None + end_index: Annotated[Optional[int], Field(ge=0)] = None + + +class LastError1(BaseModel): + code: Annotated[ + Literal["server_error", "rate_limit_exceeded"], + Field(description="One of `server_error` or `rate_limit_exceeded`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class MessageCreation(BaseModel): + message_id: Annotated[ + str, + Field(description="The ID of the message that was created by this run step."), + ] + + +class RunStepDetailsMessageCreationObject(BaseModel): + type: Annotated[Literal["message_creation"], Field(description="Always `message_creation`.")] + message_creation: MessageCreation + + +class MessageCreation1(BaseModel): + message_id: Annotated[ + Optional[str], + Field(description="The ID of the message that was created by this run step."), + ] = None + + +class RunStepDeltaStepDetailsMessageCreationObject(BaseModel): + type: Annotated[Literal["message_creation"], Field(description="Always `message_creation`.")] + message_creation: Optional[MessageCreation1] = None + + +class RunStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + type: Annotated[Literal["logs"], Field(description="Always `logs`.")] + logs: Annotated[str, Field(description="The text output from the Code Interpreter tool call.")] + + +class RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + index: Annotated[int, Field(description="The index of the output in the outputs array.")] + type: Annotated[Literal["logs"], Field(description="Always `logs`.")] + logs: Annotated[ + Optional[str], + Field(description="The text output from the Code Interpreter tool call."), + ] = None + + +class Image1(BaseModel): + file_id: Annotated[ + str, Field(description="The [file](/docs/api-reference/files) ID of the image.") + ] + + +class RunStepDetailsToolCallsCodeOutputImageObject(BaseModel): + type: Annotated[Literal["image"], Field(description="Always `image`.")] + image: Image1 + + +class Image2(BaseModel): + file_id: Annotated[ + Optional[str], + Field(description="The [file](/docs/api-reference/files) ID of the image."), + ] = None + + +class RunStepDeltaStepDetailsToolCallsCodeOutputImageObject(BaseModel): + index: Annotated[int, Field(description="The index of the output in the outputs array.")] + type: Annotated[Literal["image"], Field(description="Always `image`.")] + image: Optional[Image2] = None + + +class RunStepDetailsToolCallsFileSearchObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call object.")] + type: Annotated[ + Literal["file_search"], + Field( + description="The type of tool call. This is always going to be `file_search` for this type of tool call." + ), + ] + file_search: Annotated[ + Dict[str, Any], + Field(description="For now, this is always going to be an empty object."), + ] + + +class RunStepDeltaStepDetailsToolCallsFileSearchObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call object.")] = None + type: Annotated[ + Literal["file_search"], + Field( + description="The type of tool call. This is always going to be `file_search` for this type of tool call." + ), + ] + file_search: Annotated[ + Dict[str, Any], + Field(description="For now, this is always going to be an empty object."), + ] + + +class Function5(BaseModel): + name: Annotated[str, Field(description="The name of the function.")] + arguments: Annotated[str, Field(description="The arguments passed to the function.")] + output: Annotated[ + str, + Field( + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." + ), + ] + + +class RunStepDetailsToolCallsFunctionObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call object.")] + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call. This is always going to be `function` for this type of tool call." + ), + ] + function: Annotated[ + Function5, Field(description="The definition of the function that was called.") + ] + + +class Function6(BaseModel): + name: Annotated[Optional[str], Field(description="The name of the function.")] = None + arguments: Annotated[ + Optional[str], Field(description="The arguments passed to the function.") + ] = None + output: Annotated[ + Optional[str], + Field( + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." + ), + ] = None + + +class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call object.")] = None + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call. This is always going to be `function` for this type of tool call." + ), + ] + function: Annotated[ + Optional[Function6], + Field(description="The definition of the function that was called."), + ] = None + + +class VectorStoreExpirationAfter(BaseModel): + anchor: Annotated[ + Literal["last_active_at"], + Field( + description="Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." + ), + ] + days: Annotated[ + int, + Field( + description="The number of days after the anchor time that the vector store will expire.", + ge=1, + le=365, + ), + ] + + +class FileCounts(BaseModel): + in_progress: Annotated[ + int, + Field(description="The number of files that are currently being processed."), + ] + completed: Annotated[ + int, + Field(description="The number of files that have been successfully processed."), + ] + failed: Annotated[int, Field(description="The number of files that have failed to process.")] + cancelled: Annotated[int, Field(description="The number of files that were cancelled.")] + total: Annotated[int, Field(description="The total number of files.")] + + +class VectorStoreObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store"], + Field(description="The object type, which is always `vector_store`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the vector store was created."), + ] + name: Annotated[str, Field(description="The name of the vector store.")] + usage_bytes: Annotated[ + int, + Field(description="The total number of bytes used by the files in the vector store."), + ] + file_counts: FileCounts + status: Annotated[ + Literal["expired", "in_progress", "completed"], + Field( + description="The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use." + ), + ] + expires_after: Optional[VectorStoreExpirationAfter] = None + expires_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the vector store will expire."), + ] = None + last_active_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store was last active." + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class UpdateVectorStoreRequest(BaseModel): + class Config: + extra = Extra.forbid + + name: Annotated[Optional[str], Field(description="The name of the vector store.")] = None + expires_after: Optional[VectorStoreExpirationAfter] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class ListVectorStoresResponse(BaseModel): + object: Annotated[str, Field(example="list")] + data: List[VectorStoreObject] + first_id: Annotated[str, Field(example="vs_abc123")] + last_id: Annotated[str, Field(example="vs_abc456")] + has_more: Annotated[bool, Field(example=False)] + + +class DeleteVectorStoreResponse(BaseModel): + id: str + deleted: bool + object: Literal["vector_store.deleted"] + + +class LastError2(BaseModel): + code: Annotated[ + Literal["server_error", "unsupported_file", "invalid_file"], + Field(description="One of `server_error` or `rate_limit_exceeded`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class OtherChunkingStrategyResponseParam(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["other"], Field(description="Always `other`.")] + + +class StaticChunkingStrategy(BaseModel): + class Config: + extra = Extra.forbid + + max_chunk_size_tokens: Annotated[ + int, + Field( + description="The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`.", + ge=100, + le=4096, + ), + ] + chunk_overlap_tokens: Annotated[ + int, + Field( + description="The number of tokens that overlap between chunks. The default value is `400`.\n\nNote that the overlap must not exceed half of `max_chunk_size_tokens`.\n" + ), + ] + + +class AutoChunkingStrategyRequestParam(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class StaticChunkingStrategyRequestParam(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: StaticChunkingStrategy + + +class ChunkingStrategyRequestParam(BaseModel): + __root__: Annotated[ + Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] + + +class CreateVectorStoreFileRequest(BaseModel): + class Config: + extra = Extra.forbid + + file_id: Annotated[ + str, + Field( + description="A [File](/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files." + ), + ] + chunking_strategy: Optional[ChunkingStrategyRequestParam] = None + + +class DeleteVectorStoreFileResponse(BaseModel): + id: str + deleted: bool + object: Literal["vector_store.file.deleted"] + + +class FileCounts1(BaseModel): + in_progress: Annotated[ + int, + Field(description="The number of files that are currently being processed."), + ] + completed: Annotated[int, Field(description="The number of files that have been processed.")] + failed: Annotated[int, Field(description="The number of files that have failed to process.")] + cancelled: Annotated[int, Field(description="The number of files that where cancelled.")] + total: Annotated[int, Field(description="The total number of files.")] + + +class VectorStoreFileBatchObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store.files_batch"], + Field(description="The object type, which is always `vector_store.file_batch`."), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store files batch was created." + ), + ] + vector_store_id: Annotated[ + str, + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to." + ), + ] + status: Annotated[ + Literal["in_progress", "completed", "cancelled", "failed"], + Field( + description="The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`." + ), + ] + file_counts: FileCounts1 + + +class CreateVectorStoreFileBatchRequest(BaseModel): + class Config: + extra = Extra.forbid + + file_ids: Annotated[ + List[str], + Field( + description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", + max_items=500, + min_items=1, + ), + ] + chunking_strategy: Optional[ChunkingStrategyRequestParam] = None + + +class ThreadStreamEvent1(BaseModel): + event: Literal["thread.created"] + data: ThreadObject + + +class ThreadStreamEvent(BaseModel): + __root__: ThreadStreamEvent1 + + +class ErrorEvent(BaseModel): + event: Literal["error"] + data: Error + + +class DoneEvent(BaseModel): + event: Literal["done"] + data: Literal["[DONE]"] + + +class Datum(BaseModel): + code: Annotated[ + Optional[str], Field(description="An error code identifying the error type.") + ] = None + message: Annotated[ + Optional[str], + Field(description="A human-readable message providing more details about the error."), + ] = None + param: Annotated[ + Optional[str], + Field(description="The name of the parameter that caused the error, if applicable."), + ] = None + line: Annotated[ + Optional[int], + Field( + description="The line number of the input file where the error occurred, if applicable." + ), + ] = None + + +class Errors(BaseModel): + object: Annotated[ + Optional[str], Field(description="The object type, which is always `list`.") + ] = None + data: Optional[List[Datum]] = None + + +class RequestCounts(BaseModel): + total: Annotated[int, Field(description="Total number of requests in the batch.")] + completed: Annotated[ + int, + Field(description="Number of requests that have been completed successfully."), + ] + failed: Annotated[int, Field(description="Number of requests that have failed.")] + + +class Batch(BaseModel): + id: str + object: Annotated[ + Literal["batch"], Field(description="The object type, which is always `batch`.") + ] + endpoint: Annotated[str, Field(description="The OpenAI API endpoint used by the batch.")] + errors: Optional[Errors] = None + input_file_id: Annotated[str, Field(description="The ID of the input file for the batch.")] + completion_window: Annotated[ + str, + Field(description="The time frame within which the batch should be processed."), + ] + status: Annotated[ + Literal[ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled", + ], + Field(description="The current status of the batch."), + ] + output_file_id: Annotated[ + Optional[str], + Field( + description="The ID of the file containing the outputs of successfully executed requests." + ), + ] = None + error_file_id: Annotated[ + Optional[str], + Field(description="The ID of the file containing the outputs of requests with errors."), + ] = None + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the batch was created."), + ] + in_progress_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch started processing."), + ] = None + expires_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch will expire."), + ] = None + finalizing_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch started finalizing."), + ] = None + completed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch was completed."), + ] = None + failed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch failed."), + ] = None + expired_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch expired."), + ] = None + cancelling_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch started cancelling."), + ] = None + cancelled_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch was cancelled."), + ] = None + request_counts: Annotated[ + Optional[RequestCounts], + Field(description="The request counts for different statuses within the batch."), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class BatchRequestInput(BaseModel): + custom_id: Annotated[ + Optional[str], + Field( + description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch." + ), + ] = None + method: Annotated[ + Optional[Literal["POST"]], + Field( + description="The HTTP method to be used for the request. Currently only `POST` is supported." + ), + ] = None + url: Annotated[ + Optional[str], + Field( + description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported." + ), + ] = None + + +class Response(BaseModel): + status_code: Annotated[ + Optional[int], Field(description="The HTTP status code of the response") + ] = None + request_id: Annotated[ + Optional[str], + Field( + description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support." + ), + ] = None + body: Annotated[ + Optional[Dict[str, Any]], Field(description="The JSON body of the response") + ] = None + + +class Error2(BaseModel): + code: Annotated[Optional[str], Field(description="A machine-readable error code.")] = None + message: Annotated[Optional[str], Field(description="A human-readable error message.")] = None + + +class BatchRequestOutput(BaseModel): + id: Optional[str] = None + custom_id: Annotated[ + Optional[str], + Field( + description="A developer-provided per-request id that will be used to match outputs to inputs." + ), + ] = None + response: Optional[Response] = None + error: Annotated[ + Optional[Error2], + Field( + description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure." + ), + ] = None + + +class ListBatchesResponse(BaseModel): + data: List[Batch] + first_id: Annotated[Optional[str], Field(example="batch_abc123")] = None + last_id: Annotated[Optional[str], Field(example="batch_abc456")] = None + has_more: bool + object: Literal["list"] + + +class AuditLogActorServiceAccount(BaseModel): + id: Annotated[Optional[str], Field(description="The service account id.")] = None + + +class AuditLogActorUser(BaseModel): + id: Annotated[Optional[str], Field(description="The user id.")] = None + email: Annotated[Optional[str], Field(description="The user email.")] = None + + +class AuditLogActorApiKey(BaseModel): + id: Annotated[Optional[str], Field(description="The tracking id of the API key.")] = None + type: Annotated[ + Optional[Literal["user", "service_account"]], + Field(description="The type of API key. Can be either `user` or `service_account`."), + ] = None + user: Optional[AuditLogActorUser] = None + service_account: Optional[AuditLogActorServiceAccount] = None + + +class AuditLogActorSession(BaseModel): + user: Optional[AuditLogActorUser] = None + ip_address: Annotated[ + Optional[str], + Field(description="The IP address from which the action was performed."), + ] = None + + +class AuditLogActor(BaseModel): + type: Annotated[ + Optional[Literal["session", "api_key"]], + Field(description="The type of actor. Is either `session` or `api_key`."), + ] = None + session: Optional[AuditLogActorSession] = None + api_key: Optional[AuditLogActorApiKey] = None + + +class AuditLogEventType(BaseModel): + __root__: Annotated[ + Literal[ + "api_key.created", + "api_key.updated", + "api_key.deleted", + "invite.sent", + "invite.accepted", + "invite.deleted", + "login.succeeded", + "login.failed", + "logout.succeeded", + "logout.failed", + "organization.updated", + "project.created", + "project.updated", + "project.archived", + "service_account.created", + "service_account.updated", + "service_account.deleted", + "user.added", + "user.updated", + "user.deleted", + ], + Field(description="The event type."), + ] + + +class Project(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + name: Annotated[Optional[str], Field(description="The project title.")] = None + + +class Data(BaseModel): + scopes: Annotated[ + Optional[List[str]], + Field(description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`'), + ] = None + + +class ApiKeyCreated(BaseModel): + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None + data: Annotated[ + Optional[Data], Field(description="The payload used to create the API key.") + ] = None + + +class ChangesRequested(BaseModel): + scopes: Annotated[ + Optional[List[str]], + Field(description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`'), + ] = None + + +class ApiKeyUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested], + Field(description="The payload used to update the API key."), + ] = None + + +class ApiKeyDeleted(BaseModel): + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None + + +class Data1(BaseModel): + email: Annotated[ + Optional[str], Field(description="The email invited to the organization.") + ] = None + role: Annotated[ + Optional[str], + Field(description="The role the email was invited to be. Is either `owner` or `member`."), + ] = None + + +class InviteSent(BaseModel): + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None + data: Annotated[ + Optional[Data1], Field(description="The payload used to create the invite.") + ] = None + + +class InviteAccepted(BaseModel): + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None + + +class InviteDeleted(BaseModel): + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None + + +class LoginFailed(BaseModel): + error_code: Annotated[Optional[str], Field(description="The error code of the failure.")] = None + error_message: Annotated[ + Optional[str], Field(description="The error message of the failure.") + ] = None + + +class LogoutFailed(BaseModel): + error_code: Annotated[Optional[str], Field(description="The error code of the failure.")] = None + error_message: Annotated[ + Optional[str], Field(description="The error message of the failure.") + ] = None + + +class Settings(BaseModel): + threads_ui_visibility: Annotated[ + Optional[str], + Field( + description="Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`." + ), + ] = None + usage_dashboard_visibility: Annotated[ + Optional[str], + Field( + description="Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`." + ), + ] = None + + +class ChangesRequested1(BaseModel): + title: Annotated[Optional[str], Field(description="The organization title.")] = None + description: Annotated[Optional[str], Field(description="The organization description.")] = None + name: Annotated[Optional[str], Field(description="The organization name.")] = None + settings: Optional[Settings] = None + + +class OrganizationUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The organization ID.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested1], + Field(description="The payload used to update the organization settings."), + ] = None + + +class Data2(BaseModel): + name: Annotated[Optional[str], Field(description="The project name.")] = None + title: Annotated[ + Optional[str], + Field(description="The title of the project as seen on the dashboard."), + ] = None + + +class ProjectCreated(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + data: Annotated[ + Optional[Data2], Field(description="The payload used to create the project.") + ] = None + + +class ChangesRequested2(BaseModel): + title: Annotated[ + Optional[str], + Field(description="The title of the project as seen on the dashboard."), + ] = None + + +class ProjectUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested2], + Field(description="The payload used to update the project."), + ] = None + + +class ProjectArchived(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + + +class Data3(BaseModel): + role: Annotated[ + Optional[str], + Field(description="The role of the service account. Is either `owner` or `member`."), + ] = None + + +class ServiceAccountCreated(BaseModel): + id: Annotated[Optional[str], Field(description="The service account ID.")] = None + data: Annotated[ + Optional[Data3], + Field(description="The payload used to create the service account."), + ] = None + + +class ChangesRequested3(BaseModel): + role: Annotated[ + Optional[str], + Field(description="The role of the service account. Is either `owner` or `member`."), + ] = None + + +class ServiceAccountUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The service account ID.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested3], + Field(description="The payload used to updated the service account."), + ] = None + + +class ServiceAccountDeleted(BaseModel): + id: Annotated[Optional[str], Field(description="The service account ID.")] = None + + +class Data4(BaseModel): + role: Annotated[ + Optional[str], + Field(description="The role of the user. Is either `owner` or `member`."), + ] = None + + +class UserAdded(BaseModel): + id: Annotated[Optional[str], Field(description="The user ID.")] = None + data: Annotated[ + Optional[Data4], + Field(description="The payload used to add the user to the project."), + ] = None + + +class ChangesRequested4(BaseModel): + role: Annotated[ + Optional[str], + Field(description="The role of the user. Is either `owner` or `member`."), + ] = None + + +class UserUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested4], + Field(description="The payload used to update the user."), + ] = None + + +class UserDeleted(BaseModel): + id: Annotated[Optional[str], Field(description="The user ID.")] = None + + +class AuditLog(BaseModel): + id: Annotated[str, Field(description="The ID of this log.")] + type: AuditLogEventType + effective_at: Annotated[int, Field(description="The Unix timestamp (in seconds) of the event.")] + project: Annotated[ + Optional[Project], + Field( + description="The project that the action was scoped to. Absent for actions not scoped to projects." + ), + ] = None + actor: AuditLogActor + api_key_created: Annotated[ + Optional[ApiKeyCreated], + Field( + alias="api_key.created", + description="The details for events with this `type`.", + ), + ] = None + api_key_updated: Annotated[ + Optional[ApiKeyUpdated], + Field( + alias="api_key.updated", + description="The details for events with this `type`.", + ), + ] = None + api_key_deleted: Annotated[ + Optional[ApiKeyDeleted], + Field( + alias="api_key.deleted", + description="The details for events with this `type`.", + ), + ] = None + invite_sent: Annotated[ + Optional[InviteSent], + Field(alias="invite.sent", description="The details for events with this `type`."), + ] = None + invite_accepted: Annotated[ + Optional[InviteAccepted], + Field( + alias="invite.accepted", + description="The details for events with this `type`.", + ), + ] = None + invite_deleted: Annotated[ + Optional[InviteDeleted], + Field( + alias="invite.deleted", + description="The details for events with this `type`.", + ), + ] = None + login_failed: Annotated[ + Optional[LoginFailed], + Field(alias="login.failed", description="The details for events with this `type`."), + ] = None + logout_failed: Annotated[ + Optional[LogoutFailed], + Field( + alias="logout.failed", + description="The details for events with this `type`.", + ), + ] = None + organization_updated: Annotated[ + Optional[OrganizationUpdated], + Field( + alias="organization.updated", + description="The details for events with this `type`.", + ), + ] = None + project_created: Annotated[ + Optional[ProjectCreated], + Field( + alias="project.created", + description="The details for events with this `type`.", + ), + ] = None + project_updated: Annotated[ + Optional[ProjectUpdated], + Field( + alias="project.updated", + description="The details for events with this `type`.", + ), + ] = None + project_archived: Annotated[ + Optional[ProjectArchived], + Field( + alias="project.archived", + description="The details for events with this `type`.", + ), + ] = None + service_account_created: Annotated[ + Optional[ServiceAccountCreated], + Field( + alias="service_account.created", + description="The details for events with this `type`.", + ), + ] = None + service_account_updated: Annotated[ + Optional[ServiceAccountUpdated], + Field( + alias="service_account.updated", + description="The details for events with this `type`.", + ), + ] = None + service_account_deleted: Annotated[ + Optional[ServiceAccountDeleted], + Field( + alias="service_account.deleted", + description="The details for events with this `type`.", + ), + ] = None + user_added: Annotated[ + Optional[UserAdded], + Field(alias="user.added", description="The details for events with this `type`."), + ] = None + user_updated: Annotated[ + Optional[UserUpdated], + Field(alias="user.updated", description="The details for events with this `type`."), + ] = None + user_deleted: Annotated[ + Optional[UserDeleted], + Field(alias="user.deleted", description="The details for events with this `type`."), + ] = None + + +class ListAuditLogsResponse(BaseModel): + object: Literal["list"] + data: List[AuditLog] + first_id: Annotated[str, Field(example="audit_log-defb456h8dks")] + last_id: Annotated[str, Field(example="audit_log-hnbkd8s93s")] + has_more: bool + + +class Invite(BaseModel): + object: Annotated[ + Literal["organization.invite"], + Field(description="The object type, which is always `organization.invite`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + email: Annotated[ + str, + Field(description="The email address of the individual to whom the invite was sent"), + ] + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + status: Annotated[ + Literal["accepted", "expired", "pending"], + Field(description="`accepted`,`expired`, or `pending`"), + ] + invited_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the invite was sent."), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the invite expires."), + ] + accepted_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) of when the invite was accepted."), + ] = None + + +class InviteListResponse(BaseModel): + object: Annotated[Literal["list"], Field(description="The object type, which is always `list`")] + data: List[Invite] + first_id: Annotated[ + Optional[str], + Field(description="The first `invite_id` in the retrieved `list`"), + ] = None + last_id: Annotated[ + Optional[str], Field(description="The last `invite_id` in the retrieved `list`") + ] = None + has_more: Annotated[ + Optional[bool], + Field( + description="The `has_more` property is used for pagination to indicate there are additional results." + ), + ] = None + + +class InviteRequest(BaseModel): + email: Annotated[str, Field(description="Send an email to this address")] + role: Annotated[Literal["reader", "owner"], Field(description="`owner` or `reader`")] + + +class InviteDeleteResponse(BaseModel): + object: Annotated[ + Literal["organization.invite.deleted"], + Field(description="The object type, which is always `organization.invite.deleted`"), + ] + id: str + deleted: bool + + +class User(BaseModel): + object: Annotated[ + Literal["organization.user"], + Field(description="The object type, which is always `organization.user`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the user")] + email: Annotated[str, Field(description="The email address of the user")] + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + added_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the user was added."), + ] + + +class UserListResponse(BaseModel): + object: Literal["list"] + data: List[User] + first_id: str + last_id: str + has_more: bool + + +class UserRoleUpdateRequest(BaseModel): + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + + +class UserDeleteResponse(BaseModel): + object: Literal["organization.user.deleted"] + id: str + deleted: bool + + +class Project1(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + object: Annotated[ + Literal["organization.project"], + Field(description="The object type, which is always `organization.project`"), + ] + name: Annotated[str, Field(description="The name of the project. This appears in reporting.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the project was created."), + ] + archived_at: Annotated[ + Optional[int], + Field( + description="The Unix timestamp (in seconds) of when the project was archived or `null`." + ), + ] = None + status: Annotated[Literal["active", "archived"], Field(description="`active` or `archived`")] + + +class ProjectListResponse(BaseModel): + object: Literal["list"] + data: List[Project1] + first_id: str + last_id: str + has_more: bool + + +class ProjectCreateRequest(BaseModel): + name: Annotated[ + str, + Field(description="The friendly name of the project, this name appears in reports."), + ] + + +class ProjectUpdateRequest(BaseModel): + name: Annotated[ + str, + Field(description="The updated name of the project, this name appears in reports."), + ] + + +class DefaultProjectErrorResponse(BaseModel): + code: int + message: str + + +class ProjectUser(BaseModel): + object: Annotated[ + Literal["organization.project.user"], + Field(description="The object type, which is always `organization.project.user`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the user")] + email: Annotated[str, Field(description="The email address of the user")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + added_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the project was added."), + ] + + +class ProjectUserListResponse(BaseModel): + object: str + data: List[ProjectUser] + first_id: str + last_id: str + has_more: bool + + +class ProjectUserCreateRequest(BaseModel): + user_id: Annotated[str, Field(description="The ID of the user.")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + + +class ProjectUserUpdateRequest(BaseModel): + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + + +class ProjectUserDeleteResponse(BaseModel): + object: Literal["organization.project.user.deleted"] + id: str + deleted: bool + + +class ProjectServiceAccount(BaseModel): + object: Annotated[ + Literal["organization.project.service_account"], + Field( + description="The object type, which is always `organization.project.service_account`" + ), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the service account")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the service account was created" + ), + ] + + +class ProjectServiceAccountListResponse(BaseModel): + object: Literal["list"] + data: List[ProjectServiceAccount] + first_id: str + last_id: str + has_more: bool + + +class ProjectServiceAccountCreateRequest(BaseModel): + name: Annotated[str, Field(description="The name of the service account being created.")] + + +class ProjectServiceAccountApiKey(BaseModel): + object: Annotated[ + Literal["organization.project.service_account.api_key"], + Field( + description="The object type, which is always `organization.project.service_account.api_key`" + ), + ] + value: str + name: str + created_at: int + id: str + + +class ProjectServiceAccountDeleteResponse(BaseModel): + object: Literal["organization.project.service_account.deleted"] + id: str + deleted: bool + + +class Owner(BaseModel): + type: Annotated[ + Optional[Literal["user", "service_account"]], + Field(description="`user` or `service_account`"), + ] = None + user: Optional[ProjectUser] = None + service_account: Optional[ProjectServiceAccount] = None + + +class ProjectApiKey(BaseModel): + object: Annotated[ + Literal["organization.project.api_key"], + Field(description="The object type, which is always `organization.project.api_key`"), + ] + redacted_value: Annotated[str, Field(description="The redacted value of the API key")] + name: Annotated[str, Field(description="The name of the API key")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the API key was created"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + owner: Owner + + +class ProjectApiKeyListResponse(BaseModel): + object: Literal["list"] + data: List[ProjectApiKey] + first_id: str + last_id: str + has_more: bool + + +class ProjectApiKeyDeleteResponse(BaseModel): + object: Literal["organization.project.api_key.deleted"] + id: str + deleted: bool + + +class ListModelsResponse(BaseModel): + object: Literal["list"] + data: List[Model] + + +class CreateCompletionRequest(BaseModel): + model: Annotated[ + Union[str, Literal["gpt-3.5-turbo-instruct", "davinci-002", "babbage-002"]], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] + prompt: Annotated[ + Union[str, List[str], Prompt, Prompt1], + Field( + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n" + ), + ] + best_of: Annotated[ + Optional[int], + Field( + description='Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.\n\nWhen used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n', + ge=0, + le=20, + ), + ] = 1 + echo: Annotated[ + Optional[bool], + Field(description="Echo back the prompt in addition to the completion\n"), + ] = False + frequency_penalty: Annotated[ + Optional[float], + Field( + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] = 0 + logit_bias: Annotated[ + Optional[Dict[str, int]], + Field( + description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n' + ), + ] = None + logprobs: Annotated[ + Optional[int], + Field( + description="Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n", + ge=0, + le=5, + ), + ] = None + max_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of [tokens](/tokenizer) that can be generated in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + example=16, + ge=0, + ), + ] = 16 + n: Annotated[ + Optional[int], + Field( + description="How many completions to generate for each prompt.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n", + example=1, + ge=1, + le=128, + ), + ] = 1 + presence_penalty: Annotated[ + Optional[float], + Field( + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] = 0 + seed: Annotated[ + Optional[int], + Field( + description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\n\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ge=-9223372036854775808, + le=9223372036854775807, + ), + ] = None + stop: Annotated[ + Optional[Union[str, Stop]], + Field( + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n" + ), + ] = None + stream: Annotated[ + Optional[bool], + Field( + description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n" + ), + ] = False + stream_options: Optional[ChatCompletionStreamOptions] = None + suffix: Annotated[ + Optional[str], + Field( + description="The suffix that comes after a completion of inserted text.\n\nThis parameter is only supported for `gpt-3.5-turbo-instruct`.\n", + example="test.", + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + example=1, + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + example=1, + ge=0.0, + le=1.0, + ), + ] = 1 + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ), + ] = None + + +class CreateCompletionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the completion.")] + choices: Annotated[ + List[Choice], + Field( + description="The list of completion choices the model generated for the input prompt." + ), + ] + created: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the completion was created."), + ] + model: Annotated[str, Field(description="The model used for completion.")] + system_fingerprint: Annotated[ + Optional[str], + Field( + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" + ), + ] = None + object: Annotated[ + Literal["text_completion"], + Field(description='The object type, which is always "text_completion"'), + ] + usage: Optional[CompletionUsage] = None + + +class ChatCompletionTool(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: FunctionObject + + +class ChatCompletionToolChoiceOption(BaseModel): + __root__: Annotated[ + Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice], + Field( + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tool and instead generates a message.\n`auto` means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools.\nSpecifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n\n`none` is the default when no tools are present. `auto` is the default if tools are present.\n' + ), + ] + + +class ChatCompletionMessageToolCalls(BaseModel): + __root__: Annotated[ + List[ChatCompletionMessageToolCall], + Field(description="The tool calls generated by the model, such as function calls."), + ] + + +class ChatCompletionResponseMessage(BaseModel): + content: Annotated[str, Field(description="The contents of the message.")] + refusal: Annotated[str, Field(description="The refusal message generated by the model.")] + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + role: Annotated[ + Literal["assistant"], + Field(description="The role of the author of this message."), + ] + function_call: Annotated[ + Optional[FunctionCall], + Field( + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + ), + ] = None + + +class Choice1(BaseModel): + finish_reason: Annotated[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + message: ChatCompletionResponseMessage + logprobs: Annotated[Logprobs2, Field(description="Log probability information for the choice.")] + + +class CreateChatCompletionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the chat completion.")] + choices: Annotated[ + List[Choice1], + Field( + description="A list of chat completion choices. Can be more than one if `n` is greater than 1." + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created." + ), + ] + model: Annotated[str, Field(description="The model used for the chat completion.")] + service_tier: Annotated[ + Optional[Literal["scale", "default"]], + Field( + description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", + example="scale", + ), + ] = None + system_fingerprint: Annotated[ + Optional[str], + Field( + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" + ), + ] = None + object: Annotated[ + Literal["chat.completion"], + Field(description="The object type, which is always `chat.completion`."), + ] + usage: Optional[CompletionUsage] = None + + +class Choice2(BaseModel): + finish_reason: Annotated[ + Literal["stop", "length", "function_call", "content_filter"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + message: ChatCompletionResponseMessage + + +class CreateChatCompletionFunctionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the chat completion.")] + choices: Annotated[ + List[Choice2], + Field( + description="A list of chat completion choices. Can be more than one if `n` is greater than 1." + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created." + ), + ] + model: Annotated[str, Field(description="The model used for the chat completion.")] + system_fingerprint: Annotated[ + Optional[str], + Field( + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" + ), + ] = None + object: Annotated[ + Literal["chat.completion"], + Field(description="The object type, which is always `chat.completion`."), + ] + usage: Optional[CompletionUsage] = None + + +class ImagesResponse(BaseModel): + created: int + data: List[Image] + + +class ListFilesResponse(BaseModel): + data: List[OpenAIFile] + object: Literal["list"] + + +class ListFineTuningJobEventsResponse(BaseModel): + data: List[FineTuningJobEvent] + object: Literal["list"] + + +class ListFineTuningJobCheckpointsResponse(BaseModel): + data: List[FineTuningJobCheckpoint] + object: Literal["list"] + first_id: Optional[str] = None + last_id: Optional[str] = None + has_more: bool + + +class CreateEmbeddingResponse(BaseModel): + data: Annotated[ + List[Embedding], + Field(description="The list of embeddings generated by the model."), + ] + model: Annotated[ + str, Field(description="The name of the model used to generate the embedding.") + ] + object: Annotated[ + Literal["list"], Field(description='The object type, which is always "list".') + ] + usage: Annotated[Usage1, Field(description="The usage information for the request.")] + + +class FineTuningJob(BaseModel): + id: Annotated[ + str, + Field(description="The object identifier, which can be referenced in the API endpoints."), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job was created." + ), + ] + error: Annotated[ + Error1, + Field( + description="For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure." + ), + ] + fine_tuned_model: Annotated[ + str, + Field( + description="The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running." + ), + ] + finished_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running." + ), + ] + hyperparameters: Annotated[ + Hyperparameters1, + Field( + description="The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details." + ), + ] + model: Annotated[str, Field(description="The base model that is being fine-tuned.")] + object: Annotated[ + Literal["fine_tuning.job"], + Field(description='The object type, which is always "fine_tuning.job".'), + ] + organization_id: Annotated[ + str, Field(description="The organization that owns the fine-tuning job.") + ] + result_files: Annotated[ + List[str], + Field( + description="The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + status: Annotated[ + Literal["validating_files", "queued", "running", "succeeded", "failed", "cancelled"], + Field( + description="The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`." + ), + ] + trained_tokens: Annotated[ + int, + Field( + description="The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running." + ), + ] + training_file: Annotated[ + str, + Field( + description="The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + validation_file: Annotated[ + str, + Field( + description="The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + integrations: Annotated[ + Optional[List[FineTuningIntegration]], + Field( + description="A list of integrations to enable for this fine-tuning job.", + max_items=5, + ), + ] = None + seed: Annotated[int, Field(description="The seed used for the fine-tuning job.")] + estimated_finish: Annotated[ + Optional[int], + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running." + ), + ] = None + + +class AssistantObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["assistant"], + Field(description="The object type, which is always `assistant`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the assistant was created."), + ] + name: Annotated[ + str, + Field( + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] + description: Annotated[ + str, + Field( + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] + model: Annotated[ + str, + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] + instructions: Annotated[ + str, + Field( + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_items=128, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources], + Field( + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + example=1, + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + example=1, + ge=0.0, + le=1.0, + ), + ] = 1 + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class CreateAssistantRequest(BaseModel): + class Config: + extra = Extra.forbid + + model: Annotated[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + example="gpt-4o", + ), + ] + name: Annotated[ + Optional[str], + Field( + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] = None + description: Annotated[ + Optional[str], + Field( + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] = None + instructions: Annotated[ + Optional[str], + Field( + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] = None + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_items=128, + ), + ] = [] + tool_resources: Annotated[ + Optional[ToolResources1], + Field( + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + example=1, + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + example=1, + ge=0.0, + le=1.0, + ), + ] = 1 + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class ModifyAssistantRequest(BaseModel): + class Config: + extra = Extra.forbid + + model: Annotated[ + Optional[str], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] = None + name: Annotated[ + Optional[str], + Field( + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] = None + description: Annotated[ + Optional[str], + Field( + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] = None + instructions: Annotated[ + Optional[str], + Field( + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] = None + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_items=128, + ), + ] = [] + tool_resources: Annotated[ + Optional[ToolResources2], + Field( + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + example=1, + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + example=1, + ge=0.0, + le=1.0, + ), + ] = 1 + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class ListAssistantsResponse(BaseModel): + object: Annotated[str, Field(example="list")] + data: List[AssistantObject] + first_id: Annotated[str, Field(example="asst_abc123")] + last_id: Annotated[str, Field(example="asst_abc456")] + has_more: Annotated[bool, Field(example=False)] + + +class AssistantsApiToolChoiceOption(BaseModel): + __root__: Annotated[ + Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice], + Field( + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tools and instead generates a message.\n`auto` is the default value and means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools before responding to the user.\nSpecifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n' + ), + ] + + +class SubmitToolOutputs(BaseModel): + tool_calls: Annotated[ + List[RunToolCallObject], Field(description="A list of the relevant tool calls.") + ] + + +class RequiredAction(BaseModel): + type: Annotated[ + Literal["submit_tool_outputs"], + Field(description="For now, this is always `submit_tool_outputs`."), + ] + submit_tool_outputs: Annotated[ + SubmitToolOutputs, + Field(description="Details on the tool outputs needed for this run to continue."), + ] + + +class RunObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread.run"], + Field(description="The object type, which is always `thread.run`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was created."), + ] + thread_id: Annotated[ + str, + Field( + description="The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run." + ), + ] + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run." + ), + ] + status: Annotated[ + Literal[ + "queued", + "in_progress", + "requires_action", + "cancelling", + "cancelled", + "failed", + "completed", + "incomplete", + "expired", + ], + Field( + description="The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`." + ), + ] + required_action: Annotated[ + RequiredAction, + Field( + description="Details on the action required to continue the run. Will be `null` if no action is required." + ), + ] + last_error: Annotated[ + LastError, + Field( + description="The last error associated with this run. Will be `null` if there are no errors." + ), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run will expire."), + ] + started_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was started."), + ] + cancelled_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was cancelled."), + ] + failed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run failed."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was completed."), + ] + incomplete_details: Annotated[ + IncompleteDetails, + Field( + description="Details on why the run is incomplete. Will be `null` if the run is not incomplete." + ), + ] + model: Annotated[ + str, + Field( + description="The model that the [assistant](/docs/api-reference/assistants) used for this run." + ), + ] + instructions: Annotated[ + str, + Field( + description="The instructions that the [assistant](/docs/api-reference/assistants) used for this run." + ), + ] + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.", + max_items=20, + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + usage: RunCompletionUsage + temperature: Annotated[ + Optional[float], + Field(description="The sampling temperature used for this run. If not set, defaults to 1."), + ] = None + top_p: Annotated[ + Optional[float], + Field( + description="The nucleus sampling value used for this run. If not set, defaults to 1." + ), + ] = None + max_prompt_tokens: Annotated[ + int, + Field( + description="The maximum number of prompt tokens specified to have been used over the course of the run.\n", + ge=256, + ), + ] + max_completion_tokens: Annotated[ + int, + Field( + description="The maximum number of completion tokens specified to have been used over the course of the run.\n", + ge=256, + ), + ] + truncation_strategy: TruncationObject + tool_choice: AssistantsApiToolChoiceOption + parallel_tool_calls: ParallelToolCalls + response_format: AssistantsApiResponseFormatOption + + +class ListRunsResponse(BaseModel): + object: Annotated[str, Field(example="list")] + data: List[RunObject] + first_id: Annotated[str, Field(example="run_abc123")] + last_id: Annotated[str, Field(example="run_abc456")] + has_more: Annotated[bool, Field(example=False)] + + +class Content4(BaseModel): + __root__: Annotated[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageRequestContentTextObject, + ] + ], + Field( + description="An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview).", + min_items=1, + title="Array of content parts", + ), + ] + + +class CreateMessageRequest(BaseModel): + class Config: + extra = Extra.forbid + + role: Annotated[ + Literal["user", "assistant"], + Field( + description="The role of the entity that is creating the message. Allowed values include:\n- `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages.\n- `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation.\n" + ), + ] + content: Union[str, Content4] + attachments: Annotated[ + Optional[List[Attachment]], + Field( + description="A list of files attached to the message, and the tools they should be added to." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class Text(BaseModel): + value: Annotated[str, Field(description="The data that makes up the text.")] + annotations: List[ + Union[ + MessageContentTextAnnotationsFileCitationObject, + MessageContentTextAnnotationsFilePathObject, + ] + ] + + +class MessageContentTextObject(BaseModel): + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Text + + +class Text1(BaseModel): + value: Annotated[Optional[str], Field(description="The data that makes up the text.")] = None + annotations: Optional[ + List[ + Union[ + MessageDeltaContentTextAnnotationsFileCitationObject, + MessageDeltaContentTextAnnotationsFilePathObject, + ] + ] + ] = None + + +class MessageDeltaContentTextObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Optional[Text1] = None + + +class CodeInterpreter7(BaseModel): + input: Annotated[str, Field(description="The input to the Code Interpreter tool call.")] + outputs: Annotated[ + List[ + Union[ + RunStepDetailsToolCallsCodeOutputLogsObject, + RunStepDetailsToolCallsCodeOutputImageObject, + ] + ], + Field( + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type." + ), + ] + + +class RunStepDetailsToolCallsCodeObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call.")] + type: Annotated[ + Literal["code_interpreter"], + Field( + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." + ), + ] + code_interpreter: Annotated[ + CodeInterpreter7, + Field(description="The Code Interpreter tool call definition."), + ] + + +class CodeInterpreter8(BaseModel): + input: Annotated[ + Optional[str], Field(description="The input to the Code Interpreter tool call.") + ] = None + outputs: Annotated[ + Optional[ + List[ + Union[ + RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject, + RunStepDeltaStepDetailsToolCallsCodeOutputImageObject, + ] + ] + ], + Field( + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type." + ), + ] = None + + +class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call.")] = None + type: Annotated[ + Literal["code_interpreter"], + Field( + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." + ), + ] + code_interpreter: Annotated[ + Optional[CodeInterpreter8], + Field(description="The Code Interpreter tool call definition."), + ] = None + + +class CreateVectorStoreRequest(BaseModel): + class Config: + extra = Extra.forbid + + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", + max_items=500, + ), + ] = None + name: Annotated[Optional[str], Field(description="The name of the vector store.")] = None + expires_after: Optional[VectorStoreExpirationAfter] = None + chunking_strategy: Annotated[ + Optional[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class StaticChunkingStrategyResponseParam(BaseModel): + class Config: + extra = Extra.forbid + + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: StaticChunkingStrategy + + +class RunStreamEvent1(BaseModel): + event: Literal["thread.run.created"] + data: RunObject + + +class RunStreamEvent2(BaseModel): + event: Literal["thread.run.queued"] + data: RunObject + + +class RunStreamEvent3(BaseModel): + event: Literal["thread.run.in_progress"] + data: RunObject + + +class RunStreamEvent4(BaseModel): + event: Literal["thread.run.requires_action"] + data: RunObject + + +class RunStreamEvent5(BaseModel): + event: Literal["thread.run.completed"] + data: RunObject + + +class RunStreamEvent6(BaseModel): + event: Literal["thread.run.incomplete"] + data: RunObject + + +class RunStreamEvent7(BaseModel): + event: Literal["thread.run.failed"] + data: RunObject + + +class RunStreamEvent8(BaseModel): + event: Literal["thread.run.cancelling"] + data: RunObject + + +class RunStreamEvent9(BaseModel): + event: Literal["thread.run.cancelled"] + data: RunObject + + +class RunStreamEvent10(BaseModel): + event: Literal["thread.run.expired"] + data: RunObject + + +class RunStreamEvent(BaseModel): + __root__: Union[ + RunStreamEvent1, + RunStreamEvent2, + RunStreamEvent3, + RunStreamEvent4, + RunStreamEvent5, + RunStreamEvent6, + RunStreamEvent7, + RunStreamEvent8, + RunStreamEvent9, + RunStreamEvent10, + ] + + +class ProjectServiceAccountCreateResponse(BaseModel): + object: Literal["organization.project.service_account"] + id: str + name: str + role: Annotated[ + Literal["member"], + Field(description="Service accounts can only have one role of type `member`"), + ] + created_at: int + api_key: ProjectServiceAccountApiKey + + +class ChatCompletionRequestAssistantMessage(BaseModel): + content: Annotated[ + Optional[Union[str, Content2]], + Field( + description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n" + ), + ] = None + refusal: Annotated[ + Optional[str], Field(description="The refusal message by the assistant.") + ] = None + role: Annotated[ + Literal["assistant"], + Field(description="The role of the messages author, in this case `assistant`."), + ] + name: Annotated[ + Optional[str], + Field( + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." + ), + ] = None + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + function_call: Annotated[ + Optional[FunctionCall], + Field( + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + ), + ] = None + + +class FineTuneChatCompletionRequestAssistantMessage(ChatCompletionRequestAssistantMessage): + weight: Annotated[ + Optional[Literal[0, 1]], + Field(description="Controls whether the assistant message is trained against (0 or 1)"), + ] = None + role: Annotated[ + Literal["assistant"], + Field(description="The role of the messages author, in this case `assistant`."), + ] + + +class ListPaginatedFineTuningJobsResponse(BaseModel): + data: List[FineTuningJob] + has_more: bool + object: Literal["list"] + + +class FinetuneChatRequestInput(BaseModel): + messages: Annotated[ + Optional[ + List[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + FineTuneChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + ] + ], + Field(min_items=1), + ] = None + tools: Annotated[ + Optional[List[ChatCompletionTool]], + Field(description="A list of tools the model may generate JSON inputs for."), + ] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + functions: Annotated[ + Optional[List[ChatCompletionFunctions]], + Field( + description="A list of functions the model may generate JSON inputs for.", + max_items=128, + min_items=1, + ), + ] = None + + +class CreateRunRequest(BaseModel): + class Config: + extra = Extra.forbid + + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run." + ), + ] + model: Annotated[ + Optional[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ] + ], + Field( + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + example="gpt-4o", + ), + ] = None + instructions: Annotated[ + Optional[str], + Field( + description="Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis." + ), + ] = None + additional_instructions: Annotated[ + Optional[str], + Field( + description="Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions." + ), + ] = None + additional_messages: Annotated[ + Optional[List[CreateMessageRequest]], + Field(description="Adds additional messages to the thread before creating the run."), + ] = None + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_items=20, + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + example=1, + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + example=1, + ge=0.0, + le=1.0, + ), + ] = 1 + stream: Annotated[ + Optional[bool], + Field( + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" + ), + ] = None + max_prompt_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] = None + max_completion_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] = None + truncation_strategy: Optional[TruncationObject] = None + tool_choice: Optional[AssistantsApiToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class CreateThreadRequest(BaseModel): + class Config: + extra = Extra.forbid + + messages: Annotated[ + Optional[List[CreateMessageRequest]], + Field( + description="A list of [messages](/docs/api-reference/messages) to start the thread with." + ), + ] = None + tool_resources: Annotated[ + Optional[ToolResources5], + Field( + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class MessageObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread.message"], + Field(description="The object type, which is always `thread.message`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the message was created."), + ] + thread_id: Annotated[ + str, + Field( + description="The [thread](/docs/api-reference/threads) ID that this message belongs to." + ), + ] + status: Annotated[ + Literal["in_progress", "incomplete", "completed"], + Field( + description="The status of the message, which can be either `in_progress`, `incomplete`, or `completed`." + ), + ] + incomplete_details: Annotated[ + IncompleteDetails1, + Field(description="On an incomplete message, details about why the message is incomplete."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the message was completed."), + ] + incomplete_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the message was marked as incomplete." + ), + ] + role: Annotated[ + Literal["user", "assistant"], + Field(description="The entity that produced the message. One of `user` or `assistant`."), + ] + content: Annotated[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageContentTextObject, + MessageContentRefusalObject, + ] + ], + Field(description="The content of the message in array of text and/or images."), + ] + assistant_id: Annotated[ + str, + Field( + description="If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message." + ), + ] + run_id: Annotated[ + str, + Field( + description="The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints." + ), + ] + attachments: Annotated[ + List[Attachment], + Field( + description="A list of files attached to the message, and the tools they were added to." + ), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class Delta(BaseModel): + role: Annotated[ + Optional[Literal["user", "assistant"]], + Field(description="The entity that produced the message. One of `user` or `assistant`."), + ] = None + content: Annotated[ + Optional[ + List[ + Union[ + MessageDeltaContentImageFileObject, + MessageDeltaContentTextObject, + MessageDeltaContentRefusalObject, + MessageDeltaContentImageUrlObject, + ] + ] + ], + Field(description="The content of the message in array of text and/or images."), + ] = None + + +class MessageDeltaObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the message, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.message.delta"], + Field(description="The object type, which is always `thread.message.delta`."), + ] + delta: Annotated[ + Delta, + Field(description="The delta containing the fields that have changed on the Message."), + ] + + +class ListMessagesResponse(BaseModel): + object: Annotated[str, Field(example="list")] + data: List[MessageObject] + first_id: Annotated[str, Field(example="msg_abc123")] + last_id: Annotated[str, Field(example="msg_abc123")] + has_more: Annotated[bool, Field(example=False)] + + +class RunStepDetailsToolCallsObject(BaseModel): + type: Annotated[Literal["tool_calls"], Field(description="Always `tool_calls`.")] + tool_calls: Annotated[ + List[ + Union[ + RunStepDetailsToolCallsCodeObject, + RunStepDetailsToolCallsFileSearchObject, + RunStepDetailsToolCallsFunctionObject, + ] + ], + Field( + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n" + ), + ] + + +class RunStepDeltaStepDetailsToolCallsObject(BaseModel): + type: Annotated[Literal["tool_calls"], Field(description="Always `tool_calls`.")] + tool_calls: Annotated[ + Optional[ + List[ + Union[ + RunStepDeltaStepDetailsToolCallsCodeObject, + RunStepDeltaStepDetailsToolCallsFileSearchObject, + RunStepDeltaStepDetailsToolCallsFunctionObject, + ] + ] + ], + Field( + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n" + ), + ] = None + + +class VectorStoreFileObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store.file"], + Field(description="The object type, which is always `vector_store.file`."), + ] + usage_bytes: Annotated[ + int, + Field( + description="The total vector store usage in bytes. Note that this may be different from the original file size." + ), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store file was created." + ), + ] + vector_store_id: Annotated[ + str, + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to." + ), + ] + status: Annotated[ + Literal["in_progress", "completed", "cancelled", "failed"], + Field( + description="The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use." + ), + ] + last_error: Annotated[ + LastError2, + Field( + description="The last error associated with this vector store file. Will be `null` if there are no errors." + ), + ] + chunking_strategy: Annotated[ + Optional[Union[StaticChunkingStrategyResponseParam, OtherChunkingStrategyResponseParam]], + Field(description="The strategy used to chunk the file."), + ] = None + + +class ListVectorStoreFilesResponse(BaseModel): + object: Annotated[str, Field(example="list")] + data: List[VectorStoreFileObject] + first_id: Annotated[str, Field(example="file-abc123")] + last_id: Annotated[str, Field(example="file-abc456")] + has_more: Annotated[bool, Field(example=False)] + + +class MessageStreamEvent1(BaseModel): + event: Literal["thread.message.created"] + data: MessageObject + + +class MessageStreamEvent2(BaseModel): + event: Literal["thread.message.in_progress"] + data: MessageObject + + +class MessageStreamEvent3(BaseModel): + event: Literal["thread.message.delta"] + data: MessageDeltaObject + + +class MessageStreamEvent4(BaseModel): + event: Literal["thread.message.completed"] + data: MessageObject + + +class MessageStreamEvent5(BaseModel): + event: Literal["thread.message.incomplete"] + data: MessageObject + + +class MessageStreamEvent(BaseModel): + __root__: Union[ + MessageStreamEvent1, + MessageStreamEvent2, + MessageStreamEvent3, + MessageStreamEvent4, + MessageStreamEvent5, + ] + + +class ChatCompletionRequestMessage(BaseModel): + __root__: Annotated[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ], + Field(discriminator="role"), + ] + + +class CreateChatCompletionRequest(BaseModel): + messages: Annotated[ + List[ChatCompletionRequestMessage], + Field( + description="A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).", + min_items=1, + ), + ] + model: Annotated[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ], + Field( + description="ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.", + example="gpt-4o", + ), + ] + frequency_penalty: Annotated[ + Optional[float], + Field( + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] = 0 + logit_bias: Annotated[ + Optional[Dict[str, int]], + Field( + description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n" + ), + ] = None + logprobs: Annotated[ + Optional[bool], + Field( + description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`." + ), + ] = False + top_logprobs: Annotated[ + Optional[int], + Field( + description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.", + ge=0, + le=20, + ), + ] = None + max_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n" + ), + ] = None + n: Annotated[ + Optional[int], + Field( + description="How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs.", + example=1, + ge=1, + le=128, + ), + ] = 1 + presence_penalty: Annotated[ + Optional[float], + Field( + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] = 0 + response_format: Annotated[ + Optional[Union[ResponseFormatText, ResponseFormatJsonObject, ResponseFormatJsonSchema]], + Field( + description='An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' + ), + ] = None + seed: Annotated[ + Optional[int], + Field( + description="This feature is in Beta.\nIf specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ge=-9223372036854775808, + le=9223372036854775807, + ), + ] = None + service_tier: Annotated[ + Optional[Literal["auto", "default"]], + Field( + description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n" + ), + ] = None + stop: Annotated[ + Optional[Union[str, Stop1]], + Field(description="Up to 4 sequences where the API will stop generating further tokens.\n"), + ] = None + stream: Annotated[ + Optional[bool], + Field( + description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n" + ), + ] = False + stream_options: Optional[ChatCompletionStreamOptions] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + example=1, + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + example=1, + ge=0.0, + le=1.0, + ), + ] = 1 + tools: Annotated[ + Optional[List[ChatCompletionTool]], + Field( + description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n" + ), + ] = None + tool_choice: Optional[ChatCompletionToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + example="user-1234", + ), + ] = None + function_call: Annotated[ + Optional[Union[Literal["none", "auto"], ChatCompletionFunctionCallOption]], + Field( + description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n' + ), + ] = None + functions: Annotated[ + Optional[List[ChatCompletionFunctions]], + Field( + description="Deprecated in favor of `tools`.\n\nA list of functions the model may generate JSON inputs for.\n", + max_items=128, + min_items=1, + ), + ] = None + + +class CreateThreadAndRunRequest(BaseModel): + class Config: + extra = Extra.forbid + + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run." + ), + ] + thread: Annotated[ + Optional[CreateThreadRequest], + Field(description="If no thread is provided, an empty thread will be created."), + ] = None + model: Annotated[ + Optional[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ] + ], + Field( + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + example="gpt-4o", + ), + ] = None + instructions: Annotated[ + Optional[str], + Field( + description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis." + ), + ] = None + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_items=20, + ), + ] = None + tool_resources: Annotated[ + Optional[ToolResources3], + Field( + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + example=1, + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + example=1, + ge=0.0, + le=1.0, + ), + ] = 1 + stream: Annotated[ + Optional[bool], + Field( + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" + ), + ] = None + max_prompt_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] = None + max_completion_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] = None + truncation_strategy: Optional[TruncationObject] = None + tool_choice: Optional[AssistantsApiToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class RunStepObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the run step, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.run.step"], + Field(description="The object type, which is always `thread.run.step`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step was created."), + ] + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) associated with the run step." + ), + ] + thread_id: Annotated[ + str, + Field(description="The ID of the [thread](/docs/api-reference/threads) that was run."), + ] + run_id: Annotated[ + str, + Field( + description="The ID of the [run](/docs/api-reference/runs) that this run step is a part of." + ), + ] + type: Annotated[ + Literal["message_creation", "tool_calls"], + Field( + description="The type of run step, which can be either `message_creation` or `tool_calls`." + ), + ] + status: Annotated[ + Literal["in_progress", "cancelled", "failed", "completed", "expired"], + Field( + description="The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`." + ), + ] + step_details: Annotated[ + Union[RunStepDetailsMessageCreationObject, RunStepDetailsToolCallsObject], + Field(description="The details of the run step."), + ] + last_error: Annotated[ + LastError1, + Field( + description="The last error associated with this run step. Will be `null` if there are no errors." + ), + ] + expired_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired." + ), + ] + cancelled_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step was cancelled."), + ] + failed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step failed."), + ] + completed_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step completed."), + ] + metadata: Annotated[ + Dict[str, Any], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + usage: RunStepCompletionUsage + + +class Delta1(BaseModel): + step_details: Annotated[ + Optional[ + Union[ + RunStepDeltaStepDetailsMessageCreationObject, + RunStepDeltaStepDetailsToolCallsObject, + ] + ], + Field(description="The details of the run step."), + ] = None + + +class RunStepDeltaObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the run step, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.run.step.delta"], + Field(description="The object type, which is always `thread.run.step.delta`."), + ] + delta: Annotated[ + Delta1, + Field(description="The delta containing the fields that have changed on the run step."), + ] + + +class ListRunStepsResponse(BaseModel): + object: Annotated[str, Field(example="list")] + data: List[RunStepObject] + first_id: Annotated[str, Field(example="step_abc123")] + last_id: Annotated[str, Field(example="step_abc456")] + has_more: Annotated[bool, Field(example=False)] + + +class RunStepStreamEvent1(BaseModel): + event: Literal["thread.run.step.created"] + data: RunStepObject + + +class RunStepStreamEvent2(BaseModel): + event: Literal["thread.run.step.in_progress"] + data: RunStepObject + + +class RunStepStreamEvent3(BaseModel): + event: Literal["thread.run.step.delta"] + data: RunStepDeltaObject + + +class RunStepStreamEvent4(BaseModel): + event: Literal["thread.run.step.completed"] + data: RunStepObject + + +class RunStepStreamEvent5(BaseModel): + event: Literal["thread.run.step.failed"] + data: RunStepObject + + +class RunStepStreamEvent6(BaseModel): + event: Literal["thread.run.step.cancelled"] + data: RunStepObject + + +class RunStepStreamEvent7(BaseModel): + event: Literal["thread.run.step.expired"] + data: RunStepObject + + +class RunStepStreamEvent(BaseModel): + __root__: Union[ + RunStepStreamEvent1, + RunStepStreamEvent2, + RunStepStreamEvent3, + RunStepStreamEvent4, + RunStepStreamEvent5, + RunStepStreamEvent6, + RunStepStreamEvent7, + ] + + +class AssistantStreamEvent(BaseModel): + __root__: Annotated[ + Union[ + ThreadStreamEvent, + RunStreamEvent, + RunStepStreamEvent, + MessageStreamEvent, + ErrorEvent, + DoneEvent, + ], + Field( + description='Represents an event emitted when streaming a Run.\n\nEach event in a server-sent events stream has an `event` and `data` property:\n\n```\nevent: thread.created\ndata: {"id": "thread_123", "object": "thread", ...}\n```\n\nWe emit events whenever a new object is created, transitions to a new state, or is being\nstreamed in parts (deltas). For example, we emit `thread.run.created` when a new run\nis created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses\nto create a message during a run, we emit a `thread.message.created event`, a\n`thread.message.in_progress` event, many `thread.message.delta` events, and finally a\n`thread.message.completed` event.\n\nWe may add additional events over time, so we recommend handling unknown events gracefully\nin your code. See the [Assistants API quickstart](/docs/assistants/overview) to learn how to\nintegrate the Assistants API with streaming.\n' + ), + ] diff --git a/clients/python/llmengine/data_types/model_endpoints.py b/clients/python/llmengine/data_types/model_endpoints.py new file mode 100644 index 00000000..2e087773 --- /dev/null +++ b/clients/python/llmengine/data_types/model_endpoints.py @@ -0,0 +1,212 @@ +from typing import Any, Dict, List, Optional + +from .core import ( + CallbackAuth, + CpuSpecificationType, + GpuType, + LLMInferenceFramework, + LLMSource, + ModelEndpointStatus, + ModelEndpointType, + Quantization, + StorageSpecificationType, +) +from .pydantic_types import BaseModel, Field, HttpUrl +from .rest import GetModelEndpointResponse +from .vllm import VLLMEndpointAdditionalArgs + + +class CreateLLMEndpointRequest(VLLMEndpointAdditionalArgs, BaseModel): + name: str + + # LLM specific fields + model_name: str + source: LLMSource = LLMSource.HUGGING_FACE + inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM + inference_framework_image_tag: str = "latest" + num_shards: int = 1 + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Dict[str, Any] # TODO: JSON type + post_inference_hooks: Optional[List[str]] = None + endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + nodes_per_worker: Optional[int] = None + optimize_costs: Optional[bool] = None + min_workers: int + max_workers: int + per_worker: int + labels: Dict[str, str] + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrl] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = True # LLM endpoints are public by default. + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + +class CreateLLMEndpointResponse(BaseModel): + endpoint_creation_task_id: str + + +class GetLLMEndpointResponse(BaseModel): + """ + Response object for retrieving a Model. + """ + + id: Optional[str] = Field( + default=None, + description="(For self-hosted users) The autogenerated ID of the model.", + ) + """(For self-hosted users) The autogenerated ID of the model.""" + + name: str = Field( + description="The name of the model. Use this for making inference requests to the model." + ) + """The name of the model. Use this for making inference requests to the model.""" + + model_name: Optional[str] = Field( + default=None, + description="(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.", + ) + """(For self-hosted users) For fine-tuned models, the base model. For base models, this will be the same as `name`.""" + + source: LLMSource = Field(description="The source of the model, e.g. Hugging Face.") + """The source of the model, e.g. Hugging Face.""" + + status: ModelEndpointStatus = Field(description="The status of the model.") + """The status of the model (can be one of "READY", "UPDATE_PENDING", "UPDATE_IN_PROGRESS", "UPDATE_FAILED", "DELETE_IN_PROGRESS").""" + + inference_framework: LLMInferenceFramework = Field( + description="The inference framework used by the model." + ) + """(For self-hosted users) The inference framework used by the model.""" + + inference_framework_tag: Optional[str] = Field( + default=None, + description="(For self-hosted users) The Docker image tag used to run the model.", + ) + """(For self-hosted users) The Docker image tag used to run the model.""" + + num_shards: Optional[int] = Field( + default=None, description="(For self-hosted users) The number of shards." + ) + """(For self-hosted users) The number of shards.""" + + quantize: Optional[Quantization] = Field( + default=None, description="(For self-hosted users) The quantization method." + ) + """(For self-hosted users) The quantization method.""" + + spec: Optional[GetModelEndpointResponse] = Field( + default=None, description="(For self-hosted users) Model endpoint details." + ) + """(For self-hosted users) Model endpoint details.""" + + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + +class ListLLMEndpointsResponse(BaseModel): + """ + Response object for listing Models. + """ + + model_endpoints: List[GetLLMEndpointResponse] = Field( + ..., + description="The list of models.", + ) + """ + A list of Models, represented as `GetLLMEndpointResponse`s. + """ + + +class UpdateLLMEndpointRequest(VLLMEndpointAdditionalArgs, BaseModel): + # LLM specific fields + model_name: Optional[str] = None + source: Optional[LLMSource] = None + inference_framework_image_tag: Optional[str] = None + num_shards: Optional[int] = None + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Optional[Dict[str, Any]] = None + post_inference_hooks: Optional[List[str]] = None + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None + min_workers: Optional[int] = None + max_workers: Optional[int] = None + per_worker: Optional[int] = None + labels: Optional[Dict[str, str]] = None + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrl] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = None + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + force_bundle_recreation: Optional[bool] = False + """ + Whether to force recreate the underlying bundle. + + If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created + that we would like to pick up for existing endpoints + """ + + +class UpdateLLMEndpointResponse(BaseModel): + endpoint_creation_task_id: str + + +class DeleteLLMEndpointResponse(BaseModel): + """ + Response object for deleting a Model. + """ + + deleted: bool = Field(..., description="Whether deletion was successful.") + """ + Whether the deletion succeeded. + """ diff --git a/clients/python/llmengine/data_types/pydantic_types.py b/clients/python/llmengine/data_types/pydantic_types.py new file mode 100644 index 00000000..902f42ce --- /dev/null +++ b/clients/python/llmengine/data_types/pydantic_types.py @@ -0,0 +1,9 @@ +import pydantic + +PYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.") + +if PYDANTIC_V2: + from pydantic.v1 import BaseModel, Field, HttpUrl # noqa: F401 + +else: + from pydantic import BaseModel, Field, HttpUrl # type: ignore # noqa: F401 diff --git a/clients/python/llmengine/data_types/rest.py b/clients/python/llmengine/data_types/rest.py new file mode 100644 index 00000000..f2978cd3 --- /dev/null +++ b/clients/python/llmengine/data_types/rest.py @@ -0,0 +1,284 @@ +""" +DTOs for LLM APIs. +""" + +import datetime +from enum import Enum +from typing import Any, Dict, List, Optional + +from .core import ( + CallbackAuth, + CpuSpecificationType, + GpuType, + ModelEndpointStatus, + ModelEndpointType, + StorageSpecificationType, +) +from .pydantic_types import BaseModel, Field, HttpUrl + + +class ModelEndpointDeploymentState(BaseModel): + """ + This is the entity-layer class for the deployment settings related to a Model Endpoint. + """ + + min_workers: int = Field(..., ge=0) + max_workers: int = Field(..., ge=0) + per_worker: int = Field(..., gt=0) + available_workers: Optional[int] = Field(default=None, ge=0) + unavailable_workers: Optional[int] = Field(default=None, ge=0) + + +class ModelEndpointResourceState(BaseModel): + """ + This is the entity-layer class for the resource settings per worker of a Model Endpoint. + Note: the values for cpus/gpus/memory/storage are per node, i.e. a single "worker" may consist of + multiple underlying "nodes" (corresponding to kubernetes pods), and the values for cpus/gpus/memory/storage + are the resources allocated for a single node. Thus, the total resource allocation + for the entire worker is multiplied by the value of `nodes_per_worker`. + """ + + cpus: CpuSpecificationType # TODO(phil): try to use decimal.Decimal + gpus: int = Field(..., ge=0) + memory: StorageSpecificationType + gpu_type: Optional[GpuType] + storage: Optional[StorageSpecificationType] + nodes_per_worker: int = Field(..., ge=1) # Multinode support. >1 = multinode. + optimize_costs: Optional[bool] + + +class GetModelEndpointResponse(BaseModel): + id: str + name: str + endpoint_type: ModelEndpointType + destination: str + deployment_name: Optional[str] = Field(default=None) + metadata: Optional[Dict[str, Any]] = Field(default=None) # TODO: JSON type + bundle_name: str + status: ModelEndpointStatus + post_inference_hooks: Optional[List[str]] = Field(default=None) + default_callback_url: Optional[HttpUrl] = Field(default=None) + default_callback_auth: Optional[CallbackAuth] = Field(default=None) + labels: Optional[Dict[str, str]] = Field(default=None) + aws_role: Optional[str] = Field(default=None) + results_s3_bucket: Optional[str] = Field(default=None) + created_by: str + created_at: datetime.datetime + last_updated_at: datetime.datetime + deployment_state: Optional[ModelEndpointDeploymentState] = Field(default=None) + resource_state: Optional[ModelEndpointResourceState] = Field(default=None) + num_queued_items: Optional[int] = Field(default=None) + public_inference: Optional[bool] = Field(default=None) + + +class PostInferenceHooks(str, Enum): + """ + Post-inference hooks are functions that are called after inference is complete. + + Attributes: + CALLBACK: The callback hook is called with the inference response and the task ID. + """ + + # INSIGHT = "insight" + CALLBACK: str = "callback" + + +class CreateFineTuneRequest(BaseModel): + """ + Request object for creating a FineTune. + """ + + model: str = Field(..., description="Identifier of base model to train from.") + """Identifier of base model to train from.""" + + training_file: str = Field( + ..., + description="Path to file of training dataset. Dataset must be a csv with columns 'prompt' and 'response'.", + ) + """Path to file of training dataset. Dataset must be a csv with columns 'prompt' and 'response'.""" + + validation_file: Optional[str] = Field( + default=None, + description="Path to file of validation dataset. Has the same format as training_file. If not provided, we will generate a split from the training dataset.", + ) + """Path to file of validation dataset. Has the same format as training_file. If not provided, we will generate a split from the training dataset.""" + + hyperparameters: Optional[Dict[str, Any]] = Field( + default=None, description="Hyperparameters to pass in to training job." + ) + """Hyperparameters to pass in to training job.""" + + wandb_config: Optional[Dict[str, Any]] = Field( + default=None, description="Configuration for Weights and Biases." + ) + """ + A dict of configuration parameters for Weights & Biases. See [Weights & Biases](https://docs.wandb.ai/ref/python/init) for more information. + Set `hyperparameter["report_to"]` to `wandb` to enable automatic finetune metrics logging. + Must include `api_key` field which is the wandb API key. + Also supports setting `base_url` to use a custom Weights & Biases server. + """ + + suffix: Optional[str] = Field( + default=None, + description="Optional user-provided identifier suffix for the fine-tuned model. Can be up to 28 characters long.", + ) + """Optional user-provided identifier suffix for the fine-tuned model. Can be up to 28 characters long.""" + + +class CreateFineTuneResponse(BaseModel): + """ + Response object for creating a FineTune. + """ + + id: str = Field(..., description="ID of the created fine-tuning job.") + """ + The ID of the FineTune. + """ + + +class BatchJobStatus(str, Enum): + PENDING = "PENDING" + RUNNING = "RUNNING" + SUCCESS = "SUCCESS" + FAILURE = "FAILURE" + CANCELLED = "CANCELLED" + UNDEFINED = "UNDEFINED" + TIMEOUT = "TIMEOUT" + + +class GetFineTuneResponse(BaseModel): + """ + Response object for retrieving a FineTune. + """ + + id: str = Field(..., description="ID of the requested job.") + """ + The ID of the FineTune. + """ + + fine_tuned_model: Optional[str] = Field( + default=None, + description="Name of the resulting fine-tuned model. This can be plugged into the " + "Completion API once the fine-tune is complete", + ) + """ + The name of the resulting fine-tuned model. This can be plugged into the Completion API + once the fine-tune is complete. + """ + + status: BatchJobStatus = Field(..., description="Status of the requested job.") + """ + The status of the FineTune job. + """ + + +class ListFineTunesResponse(BaseModel): + """ + Response object for listing FineTunes. + """ + + jobs: List[GetFineTuneResponse] = Field( + ..., description="List of fine-tuning jobs and their statuses." + ) + """ + A list of FineTunes, represented as `GetFineTuneResponse`s. + """ + + +class CancelFineTuneResponse(BaseModel): + """ + Response object for cancelling a FineTune. + """ + + success: bool = Field(..., description="Whether cancellation was successful.") + """ + Whether the cancellation succeeded. + """ + + +class LLMFineTuneEvent(BaseModel): + """ + Response object one FineTune event. + """ + + timestamp: Optional[float] = Field( + description="Timestamp of the event.", + default=None, + ) + message: str = Field(description="Message of the event.") + level: str = Field(description="Logging level of the event.") + + +class GetFineTuneEventsResponse(BaseModel): + """ + Response object for getting events for a FineTune. + """ + + events: List[LLMFineTuneEvent] = Field(..., description="List of fine-tuning events.") + + +class ModelDownloadRequest(BaseModel): + """ + Request object for downloading a model. + """ + + model_name: str = Field(..., description="Name of the model to download.") + download_format: Optional[str] = Field( + default="hugging_face", + description="Desired return format for downloaded model weights (default=hugging_face).", + ) + + +class ModelDownloadResponse(BaseModel): + """ + Response object for downloading a model. + """ + + urls: Dict[str, str] = Field( + ..., + description="Dictionary of (file_name, url) pairs to download the model from.", + ) + + +class UploadFileResponse(BaseModel): + """Response object for uploading a file.""" + + id: str = Field(..., description="ID of the uploaded file.") + """ID of the uploaded file.""" + + +class GetFileResponse(BaseModel): + """Response object for retrieving a file.""" + + id: str = Field(..., description="ID of the requested file.") + """ID of the requested file.""" + + filename: str = Field(..., description="File name.") + """File name.""" + + size: int = Field(..., description="Length of the file, in characters.") + """Length of the file, in characters.""" + + +class ListFilesResponse(BaseModel): + """Response object for listing files.""" + + files: List[GetFileResponse] = Field(..., description="List of file IDs, names, and sizes.") + """List of file IDs, names, and sizes.""" + + +class DeleteFileResponse(BaseModel): + """Response object for deleting a file.""" + + deleted: bool = Field(..., description="Whether deletion was successful.") + """Whether deletion was successful.""" + + +class GetFileContentResponse(BaseModel): + """Response object for retrieving a file's content.""" + + id: str = Field(..., description="ID of the requested file.") + """ID of the requested file.""" + + content: str = Field(..., description="File content.") + """File content.""" diff --git a/clients/python/llmengine/data_types/vllm.py b/clients/python/llmengine/data_types/vllm.py new file mode 100644 index 00000000..831dcf8d --- /dev/null +++ b/clients/python/llmengine/data_types/vllm.py @@ -0,0 +1,243 @@ +from typing import Any, Dict, List, Optional, Union + +from .gen.openai import ResponseFormatJsonObject, ResponseFormatJsonSchema, ResponseFormatText +from .pydantic_types import BaseModel, Field + +# This was last synced w/ vLLM v0.5.5 on 2024-09-03 + + +class VLLMModelConfig(BaseModel): + """Model configuration for VLLM""" + + max_model_len: Optional[int] = Field( + None, + description="""Model context length, If unspecified, will be automatically derived from the model config""", + ) + + max_num_seqs: Optional[int] = Field( + None, + description="""Maximum number of sequences per iteration""", + ) + + enforce_eager: Optional[bool] = Field( + None, + description="""Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph in hybrid for maximal perforamnce and flexibility""", + ) + + gpu_memory_utilization: Optional[float] = Field( + None, + description="Maximum GPU memory utilization for the batch inference. Default to 90%.", + ) + + trust_remote_code: Optional[bool] = Field( + default=False, + description="Whether to trust remote code from Hugging face hub. This is only applicable to models whose code is not supported natively by the transformers library (e.g. deepseek). Default to False.", + ) + + +class VLLMEngineAdditionalArgs(BaseModel): + """Additional arguments to configure for vLLM that are not direct inputs to the vLLM engine""" + + max_gpu_memory_utilization: Optional[float] = Field( + None, + description="Maximum GPU memory utilization for the batch inference. Default to 90%. Deprecated in favor of specifying this in VLLMModelConfig", + ) + + attention_backend: Optional[str] = Field( + default=None, + description="Attention backend to use for vLLM. Default to None.", + ) + + +class VLLMEndpointAdditionalArgs(VLLMModelConfig, VLLMEngineAdditionalArgs, BaseModel): + pass + + +class VLLMSamplingParams(BaseModel): + best_of: Optional[int] = Field( + None, + description="""Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. + `best_of` must be greater than or equal to `n`. This is treated as + the beam width when `use_beam_search` is True. By default, `best_of` + is set to `n`.""", + ) + top_k: Optional[int] = Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ) + min_p: Optional[float] = Field( + None, + description="""Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this.""", + ) + use_beam_search: Optional[bool] = Field( + None, + description="""Whether to use beam search for sampling.""", + ) + length_penalty: Optional[float] = Field( + default=None, + description="""Float that penalizes sequences based on their length. + Used in beam search.""", + ) + repetition_penalty: Optional[float] = Field( + default=None, + description="""Float that penalizes new tokens based on whether + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens.""", + ) + early_stopping: Optional[bool] = Field( + None, + description="""Controls the stopping condition for beam search. It + accepts the following values: `True`, where the generation stops as + soon as there are `best_of` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very + unlikely to find better candidates; `"never"`, where the beam search + procedure only stops when there cannot be better candidates + (canonical beam search algorithm).""", + ) + stop_token_ids: Optional[List[int]] = Field( + default_factory=list, + description="""List of tokens that stop the generation when they are + generated. The returned output will contain the stop tokens unless + the stop tokens are special tokens.""", + ) + include_stop_str_in_output: Optional[bool] = Field( + None, + description="""Whether to include the stop strings in + output text. Defaults to False.""", + ) + ignore_eos: Optional[bool] = Field( + None, + description="""Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated.""", + ) + min_tokens: Optional[int] = Field( + None, + description="""Minimum number of tokens to generate per output sequence + before EOS or stop_token_ids can be generated""", + ) + + skip_special_tokens: Optional[bool] = Field( + True, + description="Whether to skip special tokens in the output. Only supported in vllm.", + ) + + spaces_between_special_tokens: Optional[bool] = Field( + True, + description="Whether to add spaces between special tokens in the output. Only supported in vllm.", + ) + + +class VLLMChatCompletionAdditionalParams(VLLMSamplingParams): + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the model's tokenizer " + "does not define one and no override template is given" + ), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template." + ), + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ) + + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ) + + +class VLLMCompletionAdditionalParams(VLLMSamplingParams): + add_special_tokens: Optional[bool] = Field( + default=None, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " "the prompt." + ), + ) + + response_format: Optional[ + Union[ResponseFormatText, ResponseFormatJsonObject, ResponseFormatJsonSchema] + ] = Field( + default=None, + description=( + "Similar to chat completion, this parameter specifies the format of " + "output. Only {'type': 'json_object'} or {'type': 'text' } is " + "supported." + ), + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ) + + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ) diff --git a/clients/python/llmengine/errors.py b/clients/python/llmengine/errors.py index 3383878a..27008024 100644 --- a/clients/python/llmengine/errors.py +++ b/clients/python/llmengine/errors.py @@ -81,7 +81,7 @@ def parse_error(status_code: int, content: bytes) -> Exception: try: payload = json.loads(content) message = payload["detail"] - except json.JSONDecodeError: + except (json.JSONDecodeError, KeyError): message = content.decode("utf-8") # Try to parse a APIInference error @@ -93,7 +93,7 @@ def parse_error(status_code: int, content: bytes) -> Exception: return NotFoundError(message) if status_code == 429: return RateLimitExceededError(message) - if 600 < status_code <= 500: + if 500 <= status_code < 600: return ServerError(status_code, message) # Fallback to an unknown error diff --git a/clients/python/llmengine/file.py b/clients/python/llmengine/file.py new file mode 100644 index 00000000..670efda3 --- /dev/null +++ b/clients/python/llmengine/file.py @@ -0,0 +1,196 @@ +from io import BufferedReader + +from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine +from llmengine.data_types import ( + DeleteFileResponse, + GetFileContentResponse, + GetFileResponse, + ListFilesResponse, + UploadFileResponse, +) + + +class File(APIEngine): + """ + File API. This API is used to upload private files to LLM engine so that fine-tunes can access them for training and validation data. + + Functions are provided to upload, get, list, and delete files, as well as to get the contents of a file. + """ + + @classmethod + def upload(cls, file: BufferedReader) -> UploadFileResponse: + """ + Uploads a file to LLM engine. + + For use in [FineTune creation](./#llmengine.fine_tuning.FineTune.create), this should be a CSV file with two columns: `prompt` and `response`. + A maximum of 100,000 rows of data is currently supported. + + Args: + file (`BufferedReader`): + A local file opened with `open(file_path, "r")` + + Returns: + UploadFileResponse: an object that contains the ID of the uploaded file + + === "Uploading file in Python" + ```python + from llmengine import File + + response = File.upload(open("training_dataset.csv", "r")) + + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "id": "file-abc123" + } + ``` + """ + files = {"file": file} + response = cls.post_file( + resource_name="v1/files", + files=files, + timeout=DEFAULT_TIMEOUT, + ) + return UploadFileResponse.parse_obj(response) + + @classmethod + def get(cls, file_id: str) -> GetFileResponse: + """ + Get file metadata, including filename and size. + + Args: + file_id (`str`): + ID of the file + + Returns: + GetFileResponse: an object that contains the ID, filename, and size of the requested file + + === "Getting metadata about file in Python" + ```python + from llmengine import File + + response = File.get( + file_id="file-abc123", + ) + + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "id": "file-abc123", + "filename": "training_dataset.csv", + "size": 100 + } + ``` + """ + response = cls._get(f"v1/files/{file_id}", timeout=DEFAULT_TIMEOUT) + return GetFileResponse.parse_obj(response) + + @classmethod + def list(cls) -> ListFilesResponse: + """ + List metadata about all files, e.g. their filenames and sizes. + + Returns: + ListFilesResponse: an object that contains a list of all files and their filenames and sizes + + === "Listing files in Python" + ```python + from llmengine import File + + response = File.list() + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "files": [ + { + "id": "file-abc123", + "filename": "training_dataset.csv", + "size": 100 + }, + { + "id": "file-def456", + "filename": "validation_dataset.csv", + "size": 50 + } + ] + } + ``` + """ + response = cls._get("v1/files", timeout=30) + return ListFilesResponse.parse_obj(response) + + @classmethod + def delete(cls, file_id: str) -> DeleteFileResponse: + """ + Deletes a file. + + Args: + file_id (`str`): + ID of the file + + Returns: + DeleteFileResponse: an object that contains whether the deletion was successful + + === "Deleting file in Python" + ```python + from llmengine import File + + response = File.delete(file_id="file-abc123") + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "deleted": true + } + ``` + """ + response = cls._delete( + f"v1/files/{file_id}", + timeout=DEFAULT_TIMEOUT, + ) + return DeleteFileResponse.parse_obj(response) + + @classmethod + def download(cls, file_id: str) -> GetFileContentResponse: + """ + Get contents of a file, as a string. (If the uploaded file is in binary, a string encoding will be returned.) + + Args: + file_id (`str`): + ID of the file + + Returns: + GetFileContentResponse: an object that contains the ID and content of the file + + === "Getting file content in Python" + ```python + from llmengine import File + + response = File.download(file_id="file-abc123") + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "id": "file-abc123", + "content": "Hello world!" + } + ``` + """ + response = cls._get( + f"v1/files/{file_id}/content", + timeout=DEFAULT_TIMEOUT, + ) + return GetFileContentResponse.parse_obj(response) diff --git a/clients/python/llmengine/fine_tuning.py b/clients/python/llmengine/fine_tuning.py index 62c147a6..b0f73d6b 100644 --- a/clients/python/llmengine/fine_tuning.py +++ b/clients/python/llmengine/fine_tuning.py @@ -1,4 +1,4 @@ -from typing import Dict, Optional +from typing import Any, Dict, Optional, Union from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine from llmengine.data_types import ( @@ -28,7 +28,8 @@ def create( model: str, training_file: str, validation_file: Optional[str] = None, - hyperparameters: Optional[Dict[str, str]] = None, + hyperparameters: Optional[Dict[str, Union[str, int, float]]] = None, + wandb_config: Optional[Dict[str, Any]] = None, suffix: Optional[str] = None, ) -> CreateFineTuneResponse: """ @@ -37,8 +38,10 @@ def create( This API can be used to fine-tune a model. The _model_ is the name of base model ([Model Zoo](../../model_zoo) for available models) to fine-tune. The training and validation files should consist of prompt and response pairs. `training_file` - and `validation_file` must be publicly accessible HTTP or HTTPS URLs to a CSV file - that includes two columns: `prompt` and `response`. A maximum of 100,000 rows of data is + and `validation_file` must be either publicly accessible HTTP or HTTPS URLs, or + file IDs of files uploaded to LLM Engine's [Files API](./#llmengine.File) (these + will have the `file-` prefix). The referenced files must be CSV files that include + two columns: `prompt` and `response`. A maximum of 100,000 rows of data is currently supported. At least 200 rows of data is recommended to start to see benefits from fine-tuning. For sequences longer than the native `max_seq_length` of the model, the sequences will be truncated. @@ -51,12 +54,12 @@ def create( The name of the base model to fine-tune. See [Model Zoo](../../model_zoo) for the list of available models to fine-tune. training_file (`str`): - Publicly accessible URL to a CSV file for training. + Publicly accessible URL or file ID referencing a CSV file for training. When no validation_file is provided, one will automatically be created using a 10% split of the training_file data. validation_file (`Optional[str]`): - Publicly accessible URL to a CSV file for validation. The validation file is used to compute metrics which let LLM Engine pick the best fine-tuned checkpoint, which will be used for inference when fine-tuning is complete. + Publicly accessible URL or file ID referencing a CSV file for validation. The validation file is used to compute metrics which let LLM Engine pick the best fine-tuned checkpoint, which will be used for inference when fine-tuning is complete. - hyperparameters (`Optional[Dict[str, str]]`): + hyperparameters (`Optional[Dict[str, Union[str, int, float, Dict[str, Any]]]]`): A dict of hyperparameters to customize fine-tuning behavior. Currently supported hyperparameters: @@ -65,13 +68,21 @@ def create( * `warmup_ratio`: Ratio of training steps used for learning rate warmup. (Default: 0.03) * `epochs`: Number of fine-tuning epochs. This should be less than 20. (Default: 5) * `weight_decay`: Regularization penalty applied to learned weights. (Default: 0.001) + * `peft_config`: A dict of parameters for the PEFT algorithm. See [LoraConfig](https://huggingface.co/docs/peft/main/en/package_reference/tuners#peft.LoraConfig) for more information. + + wandb_config (`Optional[Dict[str, Any]]`): + A dict of configuration parameters for Weights & Biases. See [Weights & Biases](https://docs.wandb.ai/ref/python/init) for more information. + Set `hyperparameter["report_to"]` to `wandb` to enable automatic finetune metrics logging. + Must include `api_key` field which is the wandb API key. + Also supports setting `base_url` to use a custom Weights & Biases server. suffix (`Optional[str]`): A string that will be added to your fine-tuned model name. If present, the entire fine-tuned model name - will be formatted like `"[model].[suffix].[YYYY-MM-DD-HH-MM-SS]"`. If absent, the - fine-tuned model name will be formatted `"[model].[YYYY-MM-DD-HH-MM-SS]"`. + will be formatted like `"[model].[suffix].[YYMMDD-HHMMSS]"`. If absent, the + fine-tuned model name will be formatted `"[model].[YYMMDD-HHMMSS]"`. For example, if `suffix` is `"my-experiment"`, the fine-tuned model name could be - `"llama-7b.my-experiment.2023-07-17-23-01-50"`. + `"llama-2-7b.my-experiment.230717-230150"`. + Note: `suffix` must be between 1 and 28 characters long, and can only contain alphanumeric characters and hyphens. Returns: CreateFineTuneResponse: an object that contains the ID of the created fine-tuning job @@ -99,14 +110,14 @@ def create( writer.writerows(data) ``` - Currently, data needs to be uploaded to a publicly accessible web URL so that it can be read - for fine-tuning. Publicly accessible HTTP and HTTPS URLs are currently supported. - Support for privately sharing data with the LLM Engine API is coming shortly. For quick - iteration, you can look into tools like Pastebin or GitHub Gists to quickly host your CSV - files in a public manner. An example Github Gist can be found - [here](https://gist.github.com/tigss/7cec73251a37de72756a3b15eace9965). To use the gist, - you can use the URL given when you click the “Raw” button - ([URL](https://gist.githubusercontent.com/tigss/7cec73251a37de72756a3b15eace9965/raw/85d9742890e1e6b0c06468507292893b820c13c9/llm_sample_data.csv)). + Currently, data needs to be uploaded to either a publicly accessible web URL or to LLM Engine's + private file server so that it can be read for fine-tuning. Publicly accessible HTTP and HTTPS + URLs are currently supported. + + To privately share data with the LLM Engine API, use LLM Engine's [File.upload](../../api/python_client/#llmengine.File.upload) + API. You can upload data in local file to LLM Engine's private file server and then use the + returned file ID to reference your data in the FineTune API. The file ID is generally in the + form of `file-`, e.g. "file-7DLVeLdN2Ty4M2m". Example code for fine-tuning: === "Fine-tuning in Python" @@ -114,8 +125,8 @@ def create( from llmengine import FineTune response = FineTune.create( - model="llama-7b", - training_file="https://my-bucket.s3.us-west-2.amazonaws.com/path/to/training-file.csv", + model="llama-2-7b", + training_file="file-7DLVeLdN2Ty4M2m", ) print(response.json()) @@ -134,6 +145,7 @@ def create( training_file=training_file, validation_file=validation_file, hyperparameters=hyperparameters, + wandb_config=wandb_config, suffix=suffix, ) response = cls.post_sync( @@ -283,7 +295,7 @@ def get_events(cls, fine_tune_id: str) -> GetFineTuneEventsResponse: Returns: GetFineTuneEventsResponse: an object that contains the list of events for the fine-tuning job - Example: + === "Getting events for fine-tuning jobs in Python" ```python from llmengine import FineTune @@ -291,7 +303,7 @@ def get_events(cls, fine_tune_id: str) -> GetFineTuneEventsResponse: print(response.json()) ``` - JSON Response: + === "Response in JSON" ```json { "events": diff --git a/clients/python/llmengine/model.py b/clients/python/llmengine/model.py index bbf15843..13527a67 100644 --- a/clients/python/llmengine/model.py +++ b/clients/python/llmengine/model.py @@ -1,8 +1,22 @@ -from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine +from typing import Dict, List, Optional + +from llmengine.api_engine import DEFAULT_TIMEOUT, APIEngine, assert_self_hosted from llmengine.data_types import ( + CreateLLMEndpointRequest, + CreateLLMEndpointResponse, DeleteLLMEndpointResponse, GetLLMEndpointResponse, + GpuType, ListLLMEndpointsResponse, + LLMInferenceFramework, + LLMSource, + ModelDownloadRequest, + ModelDownloadResponse, + ModelEndpointType, + PostInferenceHooks, + Quantization, + UpdateLLMEndpointRequest, + UpdateLLMEndpointResponse, ) @@ -15,10 +29,320 @@ class Model(APIEngine): See [Model Zoo](../../model_zoo) for the list of publicly available base models. """ + @classmethod + @assert_self_hosted + def create( + cls, + name: str, + # LLM specific fields + model: str, + inference_framework_image_tag: str, + source: LLMSource = LLMSource.HUGGING_FACE, + inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM, + num_shards: int = 1, + quantize: Optional[Quantization] = None, + checkpoint_path: Optional[str] = None, + max_model_len: Optional[int] = None, + # General endpoint fields + cpus: Optional[int] = None, + memory: Optional[str] = None, + storage: Optional[str] = None, + gpus: Optional[int] = None, + nodes_per_worker: int = 1, + min_workers: int = 0, + max_workers: int = 1, + per_worker: int = 2, + endpoint_type: ModelEndpointType = ModelEndpointType.STREAMING, + gpu_type: Optional[str] = None, + high_priority: Optional[bool] = False, + post_inference_hooks: Optional[List[PostInferenceHooks]] = None, + default_callback_url: Optional[str] = None, + public_inference: Optional[bool] = True, + labels: Optional[Dict[str, str]] = None, + request_headers: Optional[Dict[str, str]] = None, + ) -> CreateLLMEndpointResponse: + """ + Create an LLM model. Note: This API is only available for self-hosted users. + + Args: + name (`str`): + Name of the endpoint + + model (`str`): + Name of the base model + + inference_framework_image_tag (`str`): + Image tag for the inference framework. Use "latest" for the most recent image + + source (`LLMSource`): + Source of the LLM. Currently only HuggingFace is supported + + inference_framework (`LLMInferenceFramework`): + Inference framework for the LLM. Current supported frameworks are + LLMInferenceFramework.DEEPSPEED, LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM and LLMInferenceFramework.LIGHTLLM + + num_shards (`int`): + Number of shards for the LLM. When bigger than 1, LLM will be sharded + to multiple GPUs. Number of GPUs must be equal or larger than num_shards. + + quantize (`Optional[Quantization]`): + Quantization method for the LLM. `text_generation_inference` supports `bitsandbytes` and `vllm` supports `awq`. + + checkpoint_path (`Optional[str]`): + Remote path to the checkpoint for the LLM. LLM engine must have permission to access the given path. + Can be either a folder or a tar file. Folder is preferred since we don't need to untar and model loads faster. + For model weights, safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be longer). + + max_model_len (`Optional[int]`): + Model context length. If unspecified, will be automatically derived from the model config. + + cpus (`Optional[int]`): + Number of cpus each node in the worker should get, e.g. 1, 2, etc. This must be greater + than or equal to 1. Recommendation is set it to 8 * GPU count. Can be inferred from the model size. + + memory (`Optional[str]`): + Amount of memory each node in the worker should get, e.g. "4Gi", "512Mi", etc. This must + be a positive amount of memory. Recommendation is set it to 24Gi * GPU count. + Can be inferred from the model size. + + storage (`Optional[str]`): + Amount of local ephemeral storage each node in the worker should get, e.g. "4Gi", + "512Mi", etc. This must be a positive amount of storage. + Recommendataion is 40Gi for 7B models, 80Gi for 13B models and 200Gi for 70B models. + Can be inferred from the model size. + + gpus (`Optional[int]`): + Number of gpus each node in the worker should get, e.g. 0, 1, etc. Can be inferred from the model size. + + nodes_per_worker (`int`): + Number of nodes per worker. Used to request multinode serving. This must be greater than or equal to 1. + Controls how many nodes to dedicate to one instance of the model. + Specifically, if `nodes_per_worker` is set to greater than 1, the model will be sharded across + `nodes_per_worker` nodes (e.g. kubernetes pods). One of these nodes will be a "leader" node and receive requests. + LLM Engine will set up the inter-node communication. + Any compute resource requests (i.e. cpus, memory, storage) apply to each individual node, thus the total resources + allocated are multiplied by this number. This is useful for models that require more memory than a single node can provide. + Note: autoscaling is not supported for multinode serving. + Further note: if your model can fit on GPUs on only one machine, e.g. you have access to an 8xA100 machine and your model fits + on 8 A100s, it is recommended to set `nodes_per_worker` to 1 and the rest of the resources accordingly. + `nodes_per_worker > 1` should only be set if you require more resources than a single machine can provide. + + min_workers (`int`): + The minimum number of workers. Must be greater than or equal to 0. This + should be determined by computing the minimum throughput of your workload and + dividing it by the throughput of a single worker. When this number is 0, + max_workers must be 1, and the endpoint will autoscale between + 0 and 1 pods. When this number is greater than 0, max_workers can be any number + greater or equal to min_workers. + + max_workers (`int`): + The maximum number of workers. Must be greater than or equal to 0, + and as well as greater than or equal to ``min_workers``. This should be determined by + computing the maximum throughput of your workload and dividing it by the throughput + of a single worker + + per_worker (`int`): + The maximum number of concurrent requests that an individual worker can + service. LLM engine automatically scales the number of workers for the endpoint so that + each worker is processing ``per_worker`` requests, subject to the limits defined by + ``min_workers`` and ``max_workers`` + - If the average number of concurrent requests per worker is lower than + ``per_worker``, then the number of workers will be reduced. - Otherwise, + if the average number of concurrent requests per worker is higher than + ``per_worker``, then the number of workers will be increased to meet the elevated + traffic. + Here is our recommendation for computing ``per_worker``: + 1. Compute ``min_workers`` and ``max_workers`` per your minimum and maximum + throughput requirements. 2. Determine a value for the maximum number of + concurrent requests in the workload. Divide this number by ``max_workers``. Doing + this ensures that the number of workers will "climb" to ``max_workers``. + + endpoint_type (`ModelEndpointType`): + Currently only ``"streaming"`` endpoints are supported. + + gpu_type (`Optional[str]`): + If specifying a non-zero number of gpus, this controls the type of gpu + requested. Can be inferred from the model size. Here are the supported values: + + - ``nvidia-tesla-t4`` + - ``nvidia-ampere-a10`` + - ``nvidia-ampere-a100`` + - ``nvidia-ampere-a100e`` + - ``nvidia-hopper-h100`` + - ``nvidia-hopper-h100-1g20gb`` # 1 slice of MIG with 1g compute and 20GB memory + - ``nvidia-hopper-h100-3g40gb`` # 1 slice of MIG with 3g compute and 40GB memory + + high_priority (`Optional[bool]`): + Either ``True`` or ``False``. Enabling this will allow the created + endpoint to leverage the shared pool of prewarmed nodes for faster spinup time + + post_inference_hooks (`Optional[List[PostInferenceHooks]]`): + List of hooks to trigger after inference tasks are served + + default_callback_url (`Optional[str]`): + The default callback url to use for sync completion requests. + This can be overridden in the task parameters for each individual task. + post_inference_hooks must contain "callback" for the callback to be triggered + + public_inference (`Optional[bool]`): + If ``True``, this endpoint will be available to all user IDs for + inference + + labels (`Optional[Dict[str, str]]`): + An optional dictionary of key/value pairs to associate with this endpoint + Returns: + CreateLLMEndpointResponse: creation task ID of the created Model. Currently not used. + + === "Create Llama 2 70B model with hardware specs inferred in Python" + ```python + from llmengine import Model + + response = Model.create( + name="llama-2-70b-test" + model="llama-2-70b", + inference_framework_image_tag="0.9.4", + inference_framework=LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + num_shards=4, + checkpoint_path="s3://path/to/checkpoint", + min_workers=0, + max_workers=1, + per_worker=10, + endpoint_type=ModelEndpointType.STREAMING, + public_inference=False, + ) + + print(response.json()) + ``` + === "Create Llama 2 7B model with hardware specs specified in Python" + ```python + from llmengine import Model + + response = Model.create( + name="llama-2-7b-test" + model="llama-2-7b", + inference_framework_image_tag="0.2.1.post1", + inference_framework=LLMInferenceFramework.VLLM, + num_shards=1, + checkpoint_path="s3://path/to/checkpoint", + cpus=8, + memory="24Gi", + storage="40Gi", + gpus=1, + min_workers=0, + max_workers=1, + per_worker=10, + endpoint_type=ModelEndpointType.STREAMING, + gpu_type="nvidia-ampere-a10", + public_inference=False, + ) + + print(response.json()) + ``` + + === "Create Llama 2 13B model in Python" + ```python + from llmengine import Model + + response = Model.create( + name="llama-2-13b-test" + model="llama-2-13b", + inference_framework_image_tag="0.2.1.post1", + inference_framework=LLMInferenceFramework.VLLM, + num_shards=2, + checkpoint_path="s3://path/to/checkpoint", + cpus=16, + memory="48Gi", + storage="80Gi", + gpus=2, + min_workers=0, + max_workers=1, + per_worker=10, + endpoint_type=ModelEndpointType.STREAMING, + gpu_type="nvidia-ampere-a10", + public_inference=False, + ) + + print(response.json()) + ``` + + === "Create Llama 2 70B model with 8bit quantization in Python" + ```python + from llmengine import Model + + response = Model.create( + name="llama-2-70b-test" + model="llama-2-70b", + inference_framework_image_tag="0.9.4", + inference_framework=LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + num_shards=4, + quantize="bitsandbytes", + checkpoint_path="s3://path/to/checkpoint", + cpus=40, + memory="96Gi", + storage="200Gi", + gpus=4, + min_workers=0, + max_workers=1, + per_worker=10, + endpoint_type=ModelEndpointType.STREAMING, + gpu_type="nvidia-ampere-a10", + public_inference=False, + ) + + print(response.json()) + ``` + """ + post_inference_hooks_strs = None + if post_inference_hooks is not None: + post_inference_hooks_strs = [] + for hook in post_inference_hooks: + if isinstance(hook, PostInferenceHooks): + post_inference_hooks_strs.append(hook.value) + else: + post_inference_hooks_strs.append(hook) + + request = CreateLLMEndpointRequest( + name=name, + model_name=model, + source=source, + inference_framework=inference_framework, + inference_framework_image_tag=inference_framework_image_tag, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + max_model_len=max_model_len, + cpus=cpus, + endpoint_type=ModelEndpointType(endpoint_type), + gpus=gpus, + gpu_type=GpuType(gpu_type) if gpu_type is not None else None, + nodes_per_worker=nodes_per_worker, + labels=labels or {}, + max_workers=max_workers, + memory=memory, + metadata={}, + min_workers=min_workers, + per_worker=per_worker, + high_priority=high_priority, + post_inference_hooks=post_inference_hooks_strs, + # Pydantic automatically validates the url + default_callback_url=default_callback_url, # type: ignore + storage=storage, + public_inference=public_inference, + ) + response = cls.post_sync( + resource_name="v1/llm/model-endpoints", + data=request.dict(), + timeout=DEFAULT_TIMEOUT, + headers=request_headers, + ) + return CreateLLMEndpointResponse.parse_obj(response) + @classmethod def get( cls, model: str, + request_headers: Optional[Dict[str, str]] = None, ) -> GetLLMEndpointResponse: """ Get information about an LLM model. @@ -41,7 +365,7 @@ def get( ```python from llmengine import Model - response = Model.get("llama-7b.suffix.2023-07-18-12-00-00") + response = Model.get("llama-2-7b.suffix.2023-07-18-12-00-00") print(response.json()) ``` @@ -50,9 +374,10 @@ def get( ```json { "id": null, - "name": "llama-7b.suffix.2023-07-18-12-00-00", + "name": "llama-2-7b.suffix.2023-07-18-12-00-00", "model_name": null, "source": "hugging_face", + "status": "READY", "inference_framework": "text_generation_inference", "inference_framework_tag": null, "num_shards": null, @@ -61,11 +386,16 @@ def get( } ``` """ - response = cls._get(f"v1/llm/model-endpoints/{model}", timeout=DEFAULT_TIMEOUT) + response = cls._get( + f"v1/llm/model-endpoints/{model}", timeout=DEFAULT_TIMEOUT, headers=request_headers + ) return GetLLMEndpointResponse.parse_obj(response) @classmethod - def list(cls) -> ListLLMEndpointsResponse: + def list( + cls, + request_headers: Optional[Dict[str, str]] = None, + ) -> ListLLMEndpointsResponse: """ List LLM models available to call inference on. @@ -92,7 +422,7 @@ def list(cls) -> ListLLMEndpointsResponse: "model_endpoints": [ { "id": null, - "name": "llama-7b.suffix.2023-07-18-12-00-00", + "name": "llama-2-7b.suffix.2023-07-18-12-00-00", "model_name": null, "source": "hugging_face", "inference_framework": "text_generation_inference", @@ -103,7 +433,7 @@ def list(cls) -> ListLLMEndpointsResponse: }, { "id": null, - "name": "llama-7b", + "name": "llama-2-7b", "model_name": null, "source": "hugging_face", "inference_framework": "text_generation_inference", @@ -138,11 +468,198 @@ def list(cls) -> ListLLMEndpointsResponse: } ``` """ - response = cls._get("v1/llm/model-endpoints", timeout=DEFAULT_TIMEOUT) + response = cls._get( + "v1/llm/model-endpoints", timeout=DEFAULT_TIMEOUT, headers=request_headers + ) return ListLLMEndpointsResponse.parse_obj(response) @classmethod - def delete(cls, model: str) -> DeleteLLMEndpointResponse: + @assert_self_hosted + def update( + cls, + name: str, + # LLM specific fields + model: Optional[str] = None, + inference_framework_image_tag: Optional[str] = None, + source: Optional[LLMSource] = None, + num_shards: Optional[int] = None, + quantize: Optional[Quantization] = None, + checkpoint_path: Optional[str] = None, + # General endpoint fields + cpus: Optional[int] = None, + memory: Optional[str] = None, + storage: Optional[str] = None, + gpus: Optional[int] = None, + min_workers: Optional[int] = None, + max_workers: Optional[int] = None, + per_worker: Optional[int] = None, + endpoint_type: Optional[ModelEndpointType] = None, + gpu_type: Optional[str] = None, + high_priority: Optional[bool] = None, + post_inference_hooks: Optional[List[PostInferenceHooks]] = None, + default_callback_url: Optional[str] = None, + public_inference: Optional[bool] = None, + labels: Optional[Dict[str, str]] = None, + request_headers: Optional[Dict[str, str]] = None, + ) -> UpdateLLMEndpointResponse: + # Can't adjust nodes_per_worker + """ + Update an LLM model. Note: This API is only available for self-hosted users. + + Args: + name (`str`): + Name of the endpoint + + model (`Optional[str]`): + Name of the base model + + inference_framework_image_tag (`Optional[str]`): + Image tag for the inference framework. Use "latest" for the most recent image + + source (`Optional[LLMSource]`): + Source of the LLM. Currently only HuggingFace is supported + + num_shards (`Optional[int]`): + Number of shards for the LLM. When bigger than 1, LLM will be sharded + to multiple GPUs. Number of GPUs must be equal or larger than num_shards. + + quantize (`Optional[Quantization]`): + Quantization method for the LLM. `text_generation_inference` supports `bitsandbytes` and `vllm` supports `awq`. + + checkpoint_path (`Optional[str]`): + Remote path to the checkpoint for the LLM. LLM engine must have permission to access the given path. + Can be either a folder or a tar file. Folder is preferred since we don't need to untar and model loads faster. + For model weights, safetensors are preferred but PyTorch checkpoints are also accepted (model loading will be longer). + + cpus (`Optional[int]`): + Number of cpus each node in the worker should get, e.g. 1, 2, etc. This must be greater + than or equal to 1. Recommendation is set it to 8 * GPU count. + + memory (`Optional[str]`): + Amount of memory each node in the worker should get, e.g. "4Gi", "512Mi", etc. This must + be a positive amount of memory. Recommendation is set it to 24Gi * GPU count. + + storage (`Optional[str]`): + Amount of local ephemeral storage each node in the worker should get, e.g. "4Gi", + "512Mi", etc. This must be a positive amount of storage. + Recommendataion is 40Gi for 7B models, 80Gi for 13B models and 200Gi for 70B models. + + gpus (`Optional[int]`): + Number of gpus each node in the worker should get, e.g. 0, 1, etc. + + min_workers (`Optional[int]`): + The minimum number of workers. Must be greater than or equal to 0. This + should be determined by computing the minimum throughput of your workload and + dividing it by the throughput of a single worker. When this number is 0, + max_workers must be 1, and the endpoint will autoscale between + 0 and 1 pods. When this number is greater than 0, max_workers can be any number + greater or equal to min_workers. + + max_workers (`Optional[int]`): + The maximum number of workers. Must be greater than or equal to 0, + and as well as greater than or equal to ``min_workers``. This should be determined by + computing the maximum throughput of your workload and dividing it by the throughput + of a single worker + + per_worker (`Optional[int]`): + The maximum number of concurrent requests that an individual worker can + service. LLM engine automatically scales the number of workers for the endpoint so that + each worker is processing ``per_worker`` requests, subject to the limits defined by + ``min_workers`` and ``max_workers`` + - If the average number of concurrent requests per worker is lower than + ``per_worker``, then the number of workers will be reduced. - Otherwise, + if the average number of concurrent requests per worker is higher than + ``per_worker``, then the number of workers will be increased to meet the elevated + traffic. + Here is our recommendation for computing ``per_worker``: + 1. Compute ``min_workers`` and ``max_workers`` per your minimum and maximum + throughput requirements. 2. Determine a value for the maximum number of + concurrent requests in the workload. Divide this number by ``max_workers``. Doing + this ensures that the number of workers will "climb" to ``max_workers``. + + endpoint_type (`Optional[ModelEndpointType]`): + Currently only ``"streaming"`` endpoints are supported. + + gpu_type (`Optional[str]`): + If specifying a non-zero number of gpus, this controls the type of gpu + requested. Here are the supported values: + + - ``nvidia-tesla-t4`` + - ``nvidia-ampere-a10`` + - ``nvidia-ampere-a100`` + - ``nvidia-ampere-a100e`` + - ``nvidia-hopper-h100`` + - ``nvidia-hopper-h100-1g20gb`` + - ``nvidia-hopper-h100-3g40gb`` + + high_priority (`Optional[bool]`): + Either ``True`` or ``False``. Enabling this will allow the created + endpoint to leverage the shared pool of prewarmed nodes for faster spinup time + + post_inference_hooks (`Optional[List[PostInferenceHooks]]`): + List of hooks to trigger after inference tasks are served + + default_callback_url (`Optional[str]`): + The default callback url to use for sync completion requests. + This can be overridden in the task parameters for each individual task. + post_inference_hooks must contain "callback" for the callback to be triggered + + public_inference (`Optional[bool]`): + If ``True``, this endpoint will be available to all user IDs for + inference + + labels (`Optional[Dict[str, str]]`): + An optional dictionary of key/value pairs to associate with this endpoint + Returns: + UpdateLLMEndpointResponse: creation task ID of the updated Model. Currently not used. + """ + post_inference_hooks_strs = None + if post_inference_hooks is not None: + post_inference_hooks_strs = [] + for hook in post_inference_hooks: + if isinstance(hook, PostInferenceHooks): + post_inference_hooks_strs.append(hook.value) + else: + post_inference_hooks_strs.append(hook) + + request = UpdateLLMEndpointRequest( + model_name=model, + source=source, + inference_framework_image_tag=inference_framework_image_tag, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + cpus=cpus, + endpoint_type=ModelEndpointType(endpoint_type) if endpoint_type is not None else None, + gpus=gpus, + gpu_type=GpuType(gpu_type) if gpu_type is not None else None, + labels=labels, + max_workers=max_workers, + memory=memory, + metadata={}, + min_workers=min_workers, + per_worker=per_worker, + high_priority=high_priority, + post_inference_hooks=post_inference_hooks_strs, + # Pydantic automatically validates the url + default_callback_url=default_callback_url, # type: ignore + storage=storage, + public_inference=public_inference, + ) + response = cls.put( + resource_name=f"v1/llm/model-endpoints/{name}", + data=request.dict(), + timeout=DEFAULT_TIMEOUT, + headers=request_headers, + ) + return UpdateLLMEndpointResponse.parse_obj(response) + + @classmethod + def delete( + cls, + model_endpoint_name: str, + request_headers: Optional[Dict[str, str]] = None, + ) -> DeleteLLMEndpointResponse: """ Deletes an LLM model. @@ -153,17 +670,17 @@ def delete(cls, model: str) -> DeleteLLMEndpointResponse: Engine, an error will be thrown. Args: - model (`str`): - Name of the model + model_endpoint_name (`str`): + Name of the model endpoint to be deleted Returns: - response: whether the model was successfully deleted + response: whether the model endpoint was successfully deleted === "Deleting model in Python" ```python from llmengine import Model - response = Model.delete("llama-7b.suffix.2023-07-18-12-00-00") + response = Model.delete("llama-2-7b.suffix.2023-07-18-12-00-00") print(response.json()) ``` @@ -174,5 +691,57 @@ def delete(cls, model: str) -> DeleteLLMEndpointResponse: } ``` """ - response = cls._delete(f"v1/llm/model-endpoints/{model}", timeout=DEFAULT_TIMEOUT) + response = cls._delete( + f"v1/llm/model-endpoints/{model_endpoint_name}", + timeout=DEFAULT_TIMEOUT, + headers=request_headers, + ) return DeleteLLMEndpointResponse.parse_obj(response) + + @classmethod + def download( + cls, + model_name: str, + download_format: str = "hugging_face", + ) -> ModelDownloadResponse: + """ + Download a fine-tuned model. + + This API can be used to download the resulting model from a fine-tuning job. + It takes the `model_name` and `download_format` as parameter and returns a + response object which contains a dictonary of filename, url pairs associated + with the fine-tuned model. The user can then download these urls to obtain + the fine-tuned model. If called on a nonexistent model, an error will be thrown. + + Args: + model_name (`str`): + name of the fine-tuned model + download_format (`str`): + download format requested (default=hugging_face) + Returns: + DownloadModelResponse: an object that contains a dictionary of filenames, urls from which to download the model weights. + The urls are presigned urls that grant temporary access and expire after an hour. + + === "Downloading model in Python" + ```python + from llmengine import Model + + response = Model.download("llama-2-7b.suffix.2023-07-18-12-00-00", download_format="hugging_face") + print(response.json()) + ``` + + === "Response in JSON" + ```json + { + "urls": {"my_model_file": "https://url-to-my-model-weights"} + } + ``` + """ + + request = ModelDownloadRequest(model_name=model_name, download_format=download_format) + response = cls.post_sync( + resource_name="v1/llm/model-endpoints/download", + data=request.dict(), + timeout=DEFAULT_TIMEOUT, + ) + return ModelDownloadResponse.parse_obj(response) diff --git a/server/llm_engine_server/api/__init__.py b/clients/python/llmengine/py.typed similarity index 100% rename from server/llm_engine_server/api/__init__.py rename to clients/python/llmengine/py.typed diff --git a/clients/python/mypy.ini b/clients/python/mypy.ini index f35ae689..53164a06 100644 --- a/clients/python/mypy.ini +++ b/clients/python/mypy.ini @@ -6,3 +6,6 @@ namespace_packages = True explicit_package_bases = True strict_optional = True plugins = pydantic.mypy + +[mypy-llmengine.data_types.gen.*] +ignore_errors = True \ No newline at end of file diff --git a/clients/python/poetry.lock b/clients/python/poetry.lock index 99b869f9..2b23ca33 100644 --- a/clients/python/poetry.lock +++ b/clients/python/poetry.lock @@ -1,114 +1,127 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. + +[[package]] +name = "aiohappyeyeballs" +version = "2.4.0" +description = "Happy Eyeballs for asyncio" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiohappyeyeballs-2.4.0-py3-none-any.whl", hash = "sha256:7ce92076e249169a13c2f49320d1967425eaf1f407522d707d59cac7628d62bd"}, + {file = "aiohappyeyeballs-2.4.0.tar.gz", hash = "sha256:55a1714f084e63d49639800f95716da97a1f173d46a16dfcfda0016abb93b6b2"}, +] [[package]] name = "aiohttp" -version = "3.8.4" +version = "3.10.5" description = "Async http client/server framework (asyncio)" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "aiohttp-3.8.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:5ce45967538fb747370308d3145aa68a074bdecb4f3a300869590f725ced69c1"}, - {file = "aiohttp-3.8.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b744c33b6f14ca26b7544e8d8aadff6b765a80ad6164fb1a430bbadd593dfb1a"}, - {file = "aiohttp-3.8.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:1a45865451439eb320784918617ba54b7a377e3501fb70402ab84d38c2cd891b"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a86d42d7cba1cec432d47ab13b6637bee393a10f664c425ea7b305d1301ca1a3"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ee3c36df21b5714d49fc4580247947aa64bcbe2939d1b77b4c8dcb8f6c9faecc"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:176a64b24c0935869d5bbc4c96e82f89f643bcdf08ec947701b9dbb3c956b7dd"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c844fd628851c0bc309f3c801b3a3d58ce430b2ce5b359cd918a5a76d0b20cb5"}, - {file = "aiohttp-3.8.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5393fb786a9e23e4799fec788e7e735de18052f83682ce2dfcabaf1c00c2c08e"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e4b09863aae0dc965c3ef36500d891a3ff495a2ea9ae9171e4519963c12ceefd"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:adfbc22e87365a6e564c804c58fc44ff7727deea782d175c33602737b7feadb6"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:147ae376f14b55f4f3c2b118b95be50a369b89b38a971e80a17c3fd623f280c9"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:eafb3e874816ebe2a92f5e155f17260034c8c341dad1df25672fb710627c6949"}, - {file = "aiohttp-3.8.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c6cc15d58053c76eacac5fa9152d7d84b8d67b3fde92709195cb984cfb3475ea"}, - {file = "aiohttp-3.8.4-cp310-cp310-win32.whl", hash = "sha256:59f029a5f6e2d679296db7bee982bb3d20c088e52a2977e3175faf31d6fb75d1"}, - {file = "aiohttp-3.8.4-cp310-cp310-win_amd64.whl", hash = "sha256:fe7ba4a51f33ab275515f66b0a236bcde4fb5561498fe8f898d4e549b2e4509f"}, - {file = "aiohttp-3.8.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:3d8ef1a630519a26d6760bc695842579cb09e373c5f227a21b67dc3eb16cfea4"}, - {file = "aiohttp-3.8.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:5b3f2e06a512e94722886c0827bee9807c86a9f698fac6b3aee841fab49bbfb4"}, - {file = "aiohttp-3.8.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3a80464982d41b1fbfe3154e440ba4904b71c1a53e9cd584098cd41efdb188ef"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8b631e26df63e52f7cce0cce6507b7a7f1bc9b0c501fcde69742130b32e8782f"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3f43255086fe25e36fd5ed8f2ee47477408a73ef00e804cb2b5cba4bf2ac7f5e"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4d347a172f866cd1d93126d9b239fcbe682acb39b48ee0873c73c933dd23bd0f"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a3fec6a4cb5551721cdd70473eb009d90935b4063acc5f40905d40ecfea23e05"}, - {file = "aiohttp-3.8.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:80a37fe8f7c1e6ce8f2d9c411676e4bc633a8462844e38f46156d07a7d401654"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d1e6a862b76f34395a985b3cd39a0d949ca80a70b6ebdea37d3ab39ceea6698a"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:cd468460eefef601ece4428d3cf4562459157c0f6523db89365202c31b6daebb"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:618c901dd3aad4ace71dfa0f5e82e88b46ef57e3239fc7027773cb6d4ed53531"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:652b1bff4f15f6287550b4670546a2947f2a4575b6c6dff7760eafb22eacbf0b"}, - {file = "aiohttp-3.8.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80575ba9377c5171407a06d0196b2310b679dc752d02a1fcaa2bc20b235dbf24"}, - {file = "aiohttp-3.8.4-cp311-cp311-win32.whl", hash = "sha256:bbcf1a76cf6f6dacf2c7f4d2ebd411438c275faa1dc0c68e46eb84eebd05dd7d"}, - {file = "aiohttp-3.8.4-cp311-cp311-win_amd64.whl", hash = "sha256:6e74dd54f7239fcffe07913ff8b964e28b712f09846e20de78676ce2a3dc0bfc"}, - {file = "aiohttp-3.8.4-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:880e15bb6dad90549b43f796b391cfffd7af373f4646784795e20d92606b7a51"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:bb96fa6b56bb536c42d6a4a87dfca570ff8e52de2d63cabebfd6fb67049c34b6"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4a6cadebe132e90cefa77e45f2d2f1a4b2ce5c6b1bfc1656c1ddafcfe4ba8131"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f352b62b45dff37b55ddd7b9c0c8672c4dd2eb9c0f9c11d395075a84e2c40f75"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7ab43061a0c81198d88f39aaf90dae9a7744620978f7ef3e3708339b8ed2ef01"}, - {file = "aiohttp-3.8.4-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c9cb1565a7ad52e096a6988e2ee0397f72fe056dadf75d17fa6b5aebaea05622"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:1b3ea7edd2d24538959c1c1abf97c744d879d4e541d38305f9bd7d9b10c9ec41"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:7c7837fe8037e96b6dd5cfcf47263c1620a9d332a87ec06a6ca4564e56bd0f36"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:3b90467ebc3d9fa5b0f9b6489dfb2c304a1db7b9946fa92aa76a831b9d587e99"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:cab9401de3ea52b4b4c6971db5fb5c999bd4260898af972bf23de1c6b5dd9d71"}, - {file = "aiohttp-3.8.4-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:d1f9282c5f2b5e241034a009779e7b2a1aa045f667ff521e7948ea9b56e0c5ff"}, - {file = "aiohttp-3.8.4-cp36-cp36m-win32.whl", hash = "sha256:5e14f25765a578a0a634d5f0cd1e2c3f53964553a00347998dfdf96b8137f777"}, - {file = "aiohttp-3.8.4-cp36-cp36m-win_amd64.whl", hash = "sha256:4c745b109057e7e5f1848c689ee4fb3a016c8d4d92da52b312f8a509f83aa05e"}, - {file = "aiohttp-3.8.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:aede4df4eeb926c8fa70de46c340a1bc2c6079e1c40ccf7b0eae1313ffd33519"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4ddaae3f3d32fc2cb4c53fab020b69a05c8ab1f02e0e59665c6f7a0d3a5be54f"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4eb3b82ca349cf6fadcdc7abcc8b3a50ab74a62e9113ab7a8ebc268aad35bb9"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9bcb89336efa095ea21b30f9e686763f2be4478f1b0a616969551982c4ee4c3b"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c08e8ed6fa3d477e501ec9db169bfac8140e830aa372d77e4a43084d8dd91ab"}, - {file = "aiohttp-3.8.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c6cd05ea06daca6ad6a4ca3ba7fe7dc5b5de063ff4daec6170ec0f9979f6c332"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:b7a00a9ed8d6e725b55ef98b1b35c88013245f35f68b1b12c5cd4100dddac333"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:de04b491d0e5007ee1b63a309956eaed959a49f5bb4e84b26c8f5d49de140fa9"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:40653609b3bf50611356e6b6554e3a331f6879fa7116f3959b20e3528783e699"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:dbf3a08a06b3f433013c143ebd72c15cac33d2914b8ea4bea7ac2c23578815d6"}, - {file = "aiohttp-3.8.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:854f422ac44af92bfe172d8e73229c270dc09b96535e8a548f99c84f82dde241"}, - {file = "aiohttp-3.8.4-cp37-cp37m-win32.whl", hash = "sha256:aeb29c84bb53a84b1a81c6c09d24cf33bb8432cc5c39979021cc0f98c1292a1a"}, - {file = "aiohttp-3.8.4-cp37-cp37m-win_amd64.whl", hash = "sha256:db3fc6120bce9f446d13b1b834ea5b15341ca9ff3f335e4a951a6ead31105480"}, - {file = "aiohttp-3.8.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:fabb87dd8850ef0f7fe2b366d44b77d7e6fa2ea87861ab3844da99291e81e60f"}, - {file = "aiohttp-3.8.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:91f6d540163f90bbaef9387e65f18f73ffd7c79f5225ac3d3f61df7b0d01ad15"}, - {file = "aiohttp-3.8.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d265f09a75a79a788237d7f9054f929ced2e69eb0bb79de3798c468d8a90f945"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3d89efa095ca7d442a6d0cbc755f9e08190ba40069b235c9886a8763b03785da"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4dac314662f4e2aa5009977b652d9b8db7121b46c38f2073bfeed9f4049732cd"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fe11310ae1e4cd560035598c3f29d86cef39a83d244c7466f95c27ae04850f10"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6ddb2a2026c3f6a68c3998a6c47ab6795e4127315d2e35a09997da21865757f8"}, - {file = "aiohttp-3.8.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e75b89ac3bd27d2d043b234aa7b734c38ba1b0e43f07787130a0ecac1e12228a"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6e601588f2b502c93c30cd5a45bfc665faaf37bbe835b7cfd461753068232074"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a5d794d1ae64e7753e405ba58e08fcfa73e3fad93ef9b7e31112ef3c9a0efb52"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:a1f4689c9a1462f3df0a1f7e797791cd6b124ddbee2b570d34e7f38ade0e2c71"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:3032dcb1c35bc330134a5b8a5d4f68c1a87252dfc6e1262c65a7e30e62298275"}, - {file = "aiohttp-3.8.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8189c56eb0ddbb95bfadb8f60ea1b22fcfa659396ea36f6adcc521213cd7b44d"}, - {file = "aiohttp-3.8.4-cp38-cp38-win32.whl", hash = "sha256:33587f26dcee66efb2fff3c177547bd0449ab7edf1b73a7f5dea1e38609a0c54"}, - {file = "aiohttp-3.8.4-cp38-cp38-win_amd64.whl", hash = "sha256:e595432ac259af2d4630008bf638873d69346372d38255774c0e286951e8b79f"}, - {file = "aiohttp-3.8.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5a7bdf9e57126dc345b683c3632e8ba317c31d2a41acd5800c10640387d193ed"}, - {file = "aiohttp-3.8.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:22f6eab15b6db242499a16de87939a342f5a950ad0abaf1532038e2ce7d31567"}, - {file = "aiohttp-3.8.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:7235604476a76ef249bd64cb8274ed24ccf6995c4a8b51a237005ee7a57e8643"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ea9eb976ffdd79d0e893869cfe179a8f60f152d42cb64622fca418cd9b18dc2a"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:92c0cea74a2a81c4c76b62ea1cac163ecb20fb3ba3a75c909b9fa71b4ad493cf"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:493f5bc2f8307286b7799c6d899d388bbaa7dfa6c4caf4f97ef7521b9cb13719"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0a63f03189a6fa7c900226e3ef5ba4d3bd047e18f445e69adbd65af433add5a2"}, - {file = "aiohttp-3.8.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:10c8cefcff98fd9168cdd86c4da8b84baaa90bf2da2269c6161984e6737bf23e"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:bca5f24726e2919de94f047739d0a4fc01372801a3672708260546aa2601bf57"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:03baa76b730e4e15a45f81dfe29a8d910314143414e528737f8589ec60cf7391"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8c29c77cc57e40f84acef9bfb904373a4e89a4e8b74e71aa8075c021ec9078c2"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:03543dcf98a6619254b409be2d22b51f21ec66272be4ebda7b04e6412e4b2e14"}, - {file = "aiohttp-3.8.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:17b79c2963db82086229012cff93ea55196ed31f6493bb1ccd2c62f1724324e4"}, - {file = "aiohttp-3.8.4-cp39-cp39-win32.whl", hash = "sha256:34ce9f93a4a68d1272d26030655dd1b58ff727b3ed2a33d80ec433561b03d67a"}, - {file = "aiohttp-3.8.4-cp39-cp39-win_amd64.whl", hash = "sha256:41a86a69bb63bb2fc3dc9ad5ea9f10f1c9c8e282b471931be0268ddd09430b04"}, - {file = "aiohttp-3.8.4.tar.gz", hash = "sha256:bf2e1a9162c1e441bf805a1fd166e249d574ca04e03b34f97e2928769e91ab5c"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:18a01eba2574fb9edd5f6e5fb25f66e6ce061da5dab5db75e13fe1558142e0a3"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:94fac7c6e77ccb1ca91e9eb4cb0ac0270b9fb9b289738654120ba8cebb1189c6"}, + {file = "aiohttp-3.10.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2f1f1c75c395991ce9c94d3e4aa96e5c59c8356a15b1c9231e783865e2772699"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4f7acae3cf1a2a2361ec4c8e787eaaa86a94171d2417aae53c0cca6ca3118ff6"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:94c4381ffba9cc508b37d2e536b418d5ea9cfdc2848b9a7fea6aebad4ec6aac1"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c31ad0c0c507894e3eaa843415841995bf8de4d6b2d24c6e33099f4bc9fc0d4f"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0912b8a8fadeb32ff67a3ed44249448c20148397c1ed905d5dac185b4ca547bb"}, + {file = "aiohttp-3.10.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0d93400c18596b7dc4794d48a63fb361b01a0d8eb39f28800dc900c8fbdaca91"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:d00f3c5e0d764a5c9aa5a62d99728c56d455310bcc288a79cab10157b3af426f"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:d742c36ed44f2798c8d3f4bc511f479b9ceef2b93f348671184139e7d708042c"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:814375093edae5f1cb31e3407997cf3eacefb9010f96df10d64829362ae2df69"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:8224f98be68a84b19f48e0bdc14224b5a71339aff3a27df69989fa47d01296f3"}, + {file = "aiohttp-3.10.5-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:d9a487ef090aea982d748b1b0d74fe7c3950b109df967630a20584f9a99c0683"}, + {file = "aiohttp-3.10.5-cp310-cp310-win32.whl", hash = "sha256:d9ef084e3dc690ad50137cc05831c52b6ca428096e6deb3c43e95827f531d5ef"}, + {file = "aiohttp-3.10.5-cp310-cp310-win_amd64.whl", hash = "sha256:66bf9234e08fe561dccd62083bf67400bdbf1c67ba9efdc3dac03650e97c6088"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:8c6a4e5e40156d72a40241a25cc226051c0a8d816610097a8e8f517aeacd59a2"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:2c634a3207a5445be65536d38c13791904fda0748b9eabf908d3fe86a52941cf"}, + {file = "aiohttp-3.10.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:4aff049b5e629ef9b3e9e617fa6e2dfeda1bf87e01bcfecaf3949af9e210105e"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1942244f00baaacaa8155eca94dbd9e8cc7017deb69b75ef67c78e89fdad3c77"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e04a1f2a65ad2f93aa20f9ff9f1b672bf912413e5547f60749fa2ef8a644e061"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7f2bfc0032a00405d4af2ba27f3c429e851d04fad1e5ceee4080a1c570476697"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:424ae21498790e12eb759040bbb504e5e280cab64693d14775c54269fd1d2bb7"}, + {file = "aiohttp-3.10.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:975218eee0e6d24eb336d0328c768ebc5d617609affaca5dbbd6dd1984f16ed0"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:4120d7fefa1e2d8fb6f650b11489710091788de554e2b6f8347c7a20ceb003f5"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:b90078989ef3fc45cf9221d3859acd1108af7560c52397ff4ace8ad7052a132e"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:ba5a8b74c2a8af7d862399cdedce1533642fa727def0b8c3e3e02fcb52dca1b1"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:02594361128f780eecc2a29939d9dfc870e17b45178a867bf61a11b2a4367277"}, + {file = "aiohttp-3.10.5-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:8fb4fc029e135859f533025bc82047334e24b0d489e75513144f25408ecaf058"}, + {file = "aiohttp-3.10.5-cp311-cp311-win32.whl", hash = "sha256:e1ca1ef5ba129718a8fc827b0867f6aa4e893c56eb00003b7367f8a733a9b072"}, + {file = "aiohttp-3.10.5-cp311-cp311-win_amd64.whl", hash = "sha256:349ef8a73a7c5665cca65c88ab24abe75447e28aa3bc4c93ea5093474dfdf0ff"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:305be5ff2081fa1d283a76113b8df7a14c10d75602a38d9f012935df20731487"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:3a1c32a19ee6bbde02f1cb189e13a71b321256cc1d431196a9f824050b160d5a"}, + {file = "aiohttp-3.10.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:61645818edd40cc6f455b851277a21bf420ce347baa0b86eaa41d51ef58ba23d"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6c225286f2b13bab5987425558baa5cbdb2bc925b2998038fa028245ef421e75"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ba01ebc6175e1e6b7275c907a3a36be48a2d487549b656aa90c8a910d9f3178"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8eaf44ccbc4e35762683078b72bf293f476561d8b68ec8a64f98cf32811c323e"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1c43eb1ab7cbf411b8e387dc169acb31f0ca0d8c09ba63f9eac67829585b44f"}, + {file = "aiohttp-3.10.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:de7a5299827253023c55ea549444e058c0eb496931fa05d693b95140a947cb73"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:4790f0e15f00058f7599dab2b206d3049d7ac464dc2e5eae0e93fa18aee9e7bf"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:44b324a6b8376a23e6ba25d368726ee3bc281e6ab306db80b5819999c737d820"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:0d277cfb304118079e7044aad0b76685d30ecb86f83a0711fc5fb257ffe832ca"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:54d9ddea424cd19d3ff6128601a4a4d23d54a421f9b4c0fff740505813739a91"}, + {file = "aiohttp-3.10.5-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:4f1c9866ccf48a6df2b06823e6ae80573529f2af3a0992ec4fe75b1a510df8a6"}, + {file = "aiohttp-3.10.5-cp312-cp312-win32.whl", hash = "sha256:dc4826823121783dccc0871e3f405417ac116055bf184ac04c36f98b75aacd12"}, + {file = "aiohttp-3.10.5-cp312-cp312-win_amd64.whl", hash = "sha256:22c0a23a3b3138a6bf76fc553789cb1a703836da86b0f306b6f0dc1617398abc"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:7f6b639c36734eaa80a6c152a238242bedcee9b953f23bb887e9102976343092"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:f29930bc2921cef955ba39a3ff87d2c4398a0394ae217f41cb02d5c26c8b1b77"}, + {file = "aiohttp-3.10.5-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:f489a2c9e6455d87eabf907ac0b7d230a9786be43fbe884ad184ddf9e9c1e385"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:123dd5b16b75b2962d0fff566effb7a065e33cd4538c1692fb31c3bda2bfb972"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b98e698dc34966e5976e10bbca6d26d6724e6bdea853c7c10162a3235aba6e16"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c3b9162bab7e42f21243effc822652dc5bb5e8ff42a4eb62fe7782bcbcdfacf6"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1923a5c44061bffd5eebeef58cecf68096e35003907d8201a4d0d6f6e387ccaa"}, + {file = "aiohttp-3.10.5-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d55f011da0a843c3d3df2c2cf4e537b8070a419f891c930245f05d329c4b0689"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:afe16a84498441d05e9189a15900640a2d2b5e76cf4efe8cbb088ab4f112ee57"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:f8112fb501b1e0567a1251a2fd0747baae60a4ab325a871e975b7bb67e59221f"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:1e72589da4c90337837fdfe2026ae1952c0f4a6e793adbbfbdd40efed7c63599"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_s390x.whl", hash = "sha256:4d46c7b4173415d8e583045fbc4daa48b40e31b19ce595b8d92cf639396c15d5"}, + {file = "aiohttp-3.10.5-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:33e6bc4bab477c772a541f76cd91e11ccb6d2efa2b8d7d7883591dfb523e5987"}, + {file = "aiohttp-3.10.5-cp313-cp313-win32.whl", hash = "sha256:c58c6837a2c2a7cf3133983e64173aec11f9c2cd8e87ec2fdc16ce727bcf1a04"}, + {file = "aiohttp-3.10.5-cp313-cp313-win_amd64.whl", hash = "sha256:38172a70005252b6893088c0f5e8a47d173df7cc2b2bd88650957eb84fcf5022"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:f6f18898ace4bcd2d41a122916475344a87f1dfdec626ecde9ee802a711bc569"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:5ede29d91a40ba22ac1b922ef510aab871652f6c88ef60b9dcdf773c6d32ad7a"}, + {file = "aiohttp-3.10.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:673f988370f5954df96cc31fd99c7312a3af0a97f09e407399f61583f30da9bc"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:58718e181c56a3c02d25b09d4115eb02aafe1a732ce5714ab70326d9776457c3"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:4b38b1570242fbab8d86a84128fb5b5234a2f70c2e32f3070143a6d94bc854cf"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:074d1bff0163e107e97bd48cad9f928fa5a3eb4b9d33366137ffce08a63e37fe"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd31f176429cecbc1ba499d4aba31aaccfea488f418d60376b911269d3b883c5"}, + {file = "aiohttp-3.10.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7384d0b87d4635ec38db9263e6a3f1eb609e2e06087f0aa7f63b76833737b471"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:8989f46f3d7ef79585e98fa991e6ded55d2f48ae56d2c9fa5e491a6e4effb589"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:c83f7a107abb89a227d6c454c613e7606c12a42b9a4ca9c5d7dad25d47c776ae"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:cde98f323d6bf161041e7627a5fd763f9fd829bcfcd089804a5fdce7bb6e1b7d"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:676f94c5480d8eefd97c0c7e3953315e4d8c2b71f3b49539beb2aa676c58272f"}, + {file = "aiohttp-3.10.5-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:2d21ac12dc943c68135ff858c3a989f2194a709e6e10b4c8977d7fcd67dfd511"}, + {file = "aiohttp-3.10.5-cp38-cp38-win32.whl", hash = "sha256:17e997105bd1a260850272bfb50e2a328e029c941c2708170d9d978d5a30ad9a"}, + {file = "aiohttp-3.10.5-cp38-cp38-win_amd64.whl", hash = "sha256:1c19de68896747a2aa6257ae4cf6ef59d73917a36a35ee9d0a6f48cff0f94db8"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:7e2fe37ac654032db1f3499fe56e77190282534810e2a8e833141a021faaab0e"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:f5bf3ead3cb66ab990ee2561373b009db5bc0e857549b6c9ba84b20bc462e172"}, + {file = "aiohttp-3.10.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1b2c16a919d936ca87a3c5f0e43af12a89a3ce7ccbce59a2d6784caba945b68b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad146dae5977c4dd435eb31373b3fe9b0b1bf26858c6fc452bf6af394067e10b"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8c5c6fa16412b35999320f5c9690c0f554392dc222c04e559217e0f9ae244b92"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:95c4dc6f61d610bc0ee1edc6f29d993f10febfe5b76bb470b486d90bbece6b22"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:da452c2c322e9ce0cfef392e469a26d63d42860f829026a63374fde6b5c5876f"}, + {file = "aiohttp-3.10.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:898715cf566ec2869d5cb4d5fb4be408964704c46c96b4be267442d265390f32"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:391cc3a9c1527e424c6865e087897e766a917f15dddb360174a70467572ac6ce"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:380f926b51b92d02a34119d072f178d80bbda334d1a7e10fa22d467a66e494db"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ce91db90dbf37bb6fa0997f26574107e1b9d5ff939315247b7e615baa8ec313b"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:9093a81e18c45227eebe4c16124ebf3e0d893830c6aca7cc310bfca8fe59d857"}, + {file = "aiohttp-3.10.5-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:ee40b40aa753d844162dcc80d0fe256b87cba48ca0054f64e68000453caead11"}, + {file = "aiohttp-3.10.5-cp39-cp39-win32.whl", hash = "sha256:03f2645adbe17f274444953bdea69f8327e9d278d961d85657cb0d06864814c1"}, + {file = "aiohttp-3.10.5-cp39-cp39-win_amd64.whl", hash = "sha256:d17920f18e6ee090bdd3d0bfffd769d9f2cb4c8ffde3eb203777a3895c128862"}, + {file = "aiohttp-3.10.5.tar.gz", hash = "sha256:f071854b47d39591ce9a17981c46790acb30518e2f83dfca8db2dfa091178691"}, ] [package.dependencies] +aiohappyeyeballs = ">=2.3.0" aiosignal = ">=1.1.2" -async-timeout = ">=4.0.0a3,<5.0" -asynctest = {version = "0.13.0", markers = "python_version < \"3.8\""} +async-timeout = {version = ">=4.0,<5.0", markers = "python_version < \"3.11\""} attrs = ">=17.3.0" -charset-normalizer = ">=2.0,<4.0" frozenlist = ">=1.1.1" multidict = ">=4.5,<7.0" -typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} yarl = ">=1.0,<2.0" [package.extras] -speedups = ["Brotli", "aiodns", "cchardet"] +speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] [[package]] name = "aiosignal" @@ -125,28 +138,50 @@ files = [ frozenlist = ">=1.1.0" [[package]] -name = "async-timeout" -version = "4.0.2" -description = "Timeout context manager for asyncio programs" +name = "annotated-types" +version = "0.7.0" +description = "Reusable constraint types to use with typing.Annotated" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "async-timeout-4.0.2.tar.gz", hash = "sha256:2163e1640ddb52b7a8c80d0a67a08587e5d245cc9c553a74a847056bc2976b15"}, - {file = "async_timeout-4.0.2-py3-none-any.whl", hash = "sha256:8ca1e4fcf50d07413d66d1a5e416e42cfdf5851c981d679a09851a6853383b3c"}, + {file = "annotated_types-0.7.0-py3-none-any.whl", hash = "sha256:1f02e8b43a8fbbc3f3e0d4f0f4bfc8131bcb4eebe8849b8e5c773f3a1c582a53"}, + {file = "annotated_types-0.7.0.tar.gz", hash = "sha256:aff07c09a53a08bc8cfccb9c85b05f1aa9a2a6f23728d790723543408344ce89"}, ] [package.dependencies] -typing-extensions = {version = ">=3.6.5", markers = "python_version < \"3.8\""} +typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.9\""} [[package]] -name = "asynctest" -version = "0.13.0" -description = "Enhance the standard unittest package with features for testing asyncio libraries" +name = "anyio" +version = "4.4.0" +description = "High level compatibility layer for multiple asynchronous event loop implementations" optional = false -python-versions = ">=3.5" +python-versions = ">=3.8" files = [ - {file = "asynctest-0.13.0-py3-none-any.whl", hash = "sha256:5da6118a7e6d6b54d83a8f7197769d046922a44d2a99c21382f0a6e4fadae676"}, - {file = "asynctest-0.13.0.tar.gz", hash = "sha256:c27862842d15d83e6a34eb0b2866c323880eb3a75e4485b079ea11748fd77fac"}, + {file = "anyio-4.4.0-py3-none-any.whl", hash = "sha256:c1b2d8f46a8a812513012e1107cb0e68c17159a7a594208005a57dc776e1bdc7"}, + {file = "anyio-4.4.0.tar.gz", hash = "sha256:5aadc6a1bbb7cdb0bede386cac5e2940f5e2ff3aa20277e991cf028e0585ce94"}, +] + +[package.dependencies] +exceptiongroup = {version = ">=1.0.2", markers = "python_version < \"3.11\""} +idna = ">=2.8" +sniffio = ">=1.1" +typing-extensions = {version = ">=4.1", markers = "python_version < \"3.11\""} + +[package.extras] +doc = ["Sphinx (>=7)", "packaging", "sphinx-autodoc-typehints (>=1.2.0)", "sphinx-rtd-theme"] +test = ["anyio[trio]", "coverage[toml] (>=7)", "exceptiongroup (>=1.2.0)", "hypothesis (>=4.0)", "psutil (>=5.9)", "pytest (>=7.0)", "pytest-mock (>=3.6.1)", "trustme", "uvloop (>=0.17)"] +trio = ["trio (>=0.23)"] + +[[package]] +name = "async-timeout" +version = "4.0.3" +description = "Timeout context manager for asyncio programs" +optional = false +python-versions = ">=3.7" +files = [ + {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, + {file = "async_timeout-4.0.3-py3-none-any.whl", hash = "sha256:7405140ff1230c310e51dc27b3145b9092d659ce68ff733fb0cefe3ee42be028"}, ] [[package]] @@ -161,118 +196,131 @@ files = [ [[package]] name = "attrs" -version = "23.1.0" +version = "24.2.0" description = "Classes Without Boilerplate" optional = false python-versions = ">=3.7" files = [ - {file = "attrs-23.1.0-py3-none-any.whl", hash = "sha256:1f28b4522cdc2fb4256ac1a020c78acf9cba2c6b461ccd2c126f3aa8e8335d04"}, - {file = "attrs-23.1.0.tar.gz", hash = "sha256:6279836d581513a26f1bf235f9acd333bc9115683f14f7e8fae46c98fc50e015"}, + {file = "attrs-24.2.0-py3-none-any.whl", hash = "sha256:81921eb96de3191c8258c199618104dd27ac608d9366f5e35d011eae1867ede2"}, + {file = "attrs-24.2.0.tar.gz", hash = "sha256:5cfb1b9148b5b086569baec03f20d7b6bf3bcacc9a42bebf87ffaaca362f6346"}, ] -[package.dependencies] -importlib-metadata = {version = "*", markers = "python_version < \"3.8\""} - [package.extras] -cov = ["attrs[tests]", "coverage[toml] (>=5.3)"] -dev = ["attrs[docs,tests]", "pre-commit"] -docs = ["furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier", "zope-interface"] -tests = ["attrs[tests-no-zope]", "zope-interface"] -tests-no-zope = ["cloudpickle", "hypothesis", "mypy (>=1.1.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +benchmark = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-codspeed", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +cov = ["cloudpickle", "coverage[toml] (>=5.3)", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +dev = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pre-commit", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +docs = ["cogapp", "furo", "myst-parser", "sphinx", "sphinx-notfound-page", "sphinxcontrib-towncrier", "towncrier (<24.7)"] +tests = ["cloudpickle", "hypothesis", "mypy (>=1.11.1)", "pympler", "pytest (>=4.3.0)", "pytest-mypy-plugins", "pytest-xdist[psutil]"] +tests-mypy = ["mypy (>=1.11.1)", "pytest-mypy-plugins"] [[package]] name = "certifi" -version = "2023.5.7" +version = "2024.7.4" description = "Python package for providing Mozilla's CA Bundle." optional = false python-versions = ">=3.6" files = [ - {file = "certifi-2023.5.7-py3-none-any.whl", hash = "sha256:c6c2e98f5c7869efca1f8916fed228dd91539f9f1b444c314c06eef02980c716"}, - {file = "certifi-2023.5.7.tar.gz", hash = "sha256:0f0d56dc5a6ad56fd4ba36484d6cc34451e1c6548c61daad8c320169f91eddc7"}, + {file = "certifi-2024.7.4-py3-none-any.whl", hash = "sha256:c198e21b1289c2ab85ee4e67bb4b4ef3ead0892059901a8d5b622f24a1101e90"}, + {file = "certifi-2024.7.4.tar.gz", hash = "sha256:5a1e7645bc0ec61a09e26c36f6106dd4cf40c6db3a1fb6352b0244e7fb057c7b"}, ] [[package]] name = "charset-normalizer" -version = "3.2.0" +version = "3.3.2" description = "The Real First Universal Charset Detector. Open, modern and actively maintained alternative to Chardet." optional = false python-versions = ">=3.7.0" files = [ - {file = "charset-normalizer-3.2.0.tar.gz", hash = "sha256:3bb3d25a8e6c0aedd251753a79ae98a093c7e7b471faa3aa9a93a81431987ace"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b87549028f680ca955556e3bd57013ab47474c3124dc069faa0b6545b6c9710"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:7c70087bfee18a42b4040bb9ec1ca15a08242cf5867c58726530bdf3945672ed"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a103b3a7069b62f5d4890ae1b8f0597618f628b286b03d4bc9195230b154bfa9"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:94aea8eff76ee6d1cdacb07dd2123a68283cb5569e0250feab1240058f53b623"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:db901e2ac34c931d73054d9797383d0f8009991e723dab15109740a63e7f902a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b0dac0ff919ba34d4df1b6131f59ce95b08b9065233446be7e459f95554c0dc8"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:193cbc708ea3aca45e7221ae58f0fd63f933753a9bfb498a3b474878f12caaad"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09393e1b2a9461950b1c9a45d5fd251dc7c6f228acab64da1c9c0165d9c7765c"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:baacc6aee0b2ef6f3d308e197b5d7a81c0e70b06beae1f1fcacffdbd124fe0e3"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bf420121d4c8dce6b889f0e8e4ec0ca34b7f40186203f06a946fa0276ba54029"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:c04a46716adde8d927adb9457bbe39cf473e1e2c2f5d0a16ceb837e5d841ad4f"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:aaf63899c94de41fe3cf934601b0f7ccb6b428c6e4eeb80da72c58eab077b19a"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62e51710986674142526ab9f78663ca2b0726066ae26b78b22e0f5e571238dd"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win32.whl", hash = "sha256:04e57ab9fbf9607b77f7d057974694b4f6b142da9ed4a199859d9d4d5c63fe96"}, - {file = "charset_normalizer-3.2.0-cp310-cp310-win_amd64.whl", hash = "sha256:48021783bdf96e3d6de03a6e39a1171ed5bd7e8bb93fc84cc649d11490f87cea"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:4957669ef390f0e6719db3613ab3a7631e68424604a7b448f079bee145da6e09"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:46fb8c61d794b78ec7134a715a3e564aafc8f6b5e338417cb19fe9f57a5a9bf2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f779d3ad205f108d14e99bb3859aa7dd8e9c68874617c72354d7ecaec2a054ac"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f25c229a6ba38a35ae6e25ca1264621cc25d4d38dca2942a7fce0b67a4efe918"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2efb1bd13885392adfda4614c33d3b68dee4921fd0ac1d3988f8cbb7d589e72a"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f30b48dd7fa1474554b0b0f3fdfdd4c13b5c737a3c6284d3cdc424ec0ffff3a"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:246de67b99b6851627d945db38147d1b209a899311b1305dd84916f2b88526c6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bd9b3b31adcb054116447ea22caa61a285d92e94d710aa5ec97992ff5eb7cf3"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:8c2f5e83493748286002f9369f3e6607c565a6a90425a3a1fef5ae32a36d749d"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:3170c9399da12c9dc66366e9d14da8bf7147e1e9d9ea566067bbce7bb74bd9c2"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7a4826ad2bd6b07ca615c74ab91f32f6c96d08f6fcc3902ceeedaec8cdc3bcd6"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:3b1613dd5aee995ec6d4c69f00378bbd07614702a315a2cf6c1d21461fe17c23"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9e608aafdb55eb9f255034709e20d5a83b6d60c054df0802fa9c9883d0a937aa"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win32.whl", hash = "sha256:f2a1d0fd4242bd8643ce6f98927cf9c04540af6efa92323e9d3124f57727bfc1"}, - {file = "charset_normalizer-3.2.0-cp311-cp311-win_amd64.whl", hash = "sha256:681eb3d7e02e3c3655d1b16059fbfb605ac464c834a0c629048a30fad2b27489"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:c57921cda3a80d0f2b8aec7e25c8aa14479ea92b5b51b6876d975d925a2ea346"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41b25eaa7d15909cf3ac4c96088c1f266a9a93ec44f87f1d13d4a0e86c81b982"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f058f6963fd82eb143c692cecdc89e075fa0828db2e5b291070485390b2f1c9c"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7647ebdfb9682b7bb97e2a5e7cb6ae735b1c25008a70b906aecca294ee96cf4"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:eef9df1eefada2c09a5e7a40991b9fc6ac6ef20b1372abd48d2794a316dc0449"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e03b8895a6990c9ab2cdcd0f2fe44088ca1c65ae592b8f795c3294af00a461c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:ee4006268ed33370957f55bf2e6f4d263eaf4dc3cfc473d1d90baff6ed36ce4a"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c4983bf937209c57240cff65906b18bb35e64ae872da6a0db937d7b4af845dd7"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:3bb7fda7260735efe66d5107fb7e6af6a7c04c7fce9b2514e04b7a74b06bf5dd"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:72814c01533f51d68702802d74f77ea026b5ec52793c791e2da806a3844a46c3"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:70c610f6cbe4b9fce272c407dd9d07e33e6bf7b4aa1b7ffb6f6ded8e634e3592"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-win32.whl", hash = "sha256:a401b4598e5d3f4a9a811f3daf42ee2291790c7f9d74b18d75d6e21dda98a1a1"}, - {file = "charset_normalizer-3.2.0-cp37-cp37m-win_amd64.whl", hash = "sha256:c0b21078a4b56965e2b12f247467b234734491897e99c1d51cee628da9786959"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:95eb302ff792e12aba9a8b8f8474ab229a83c103d74a750ec0bd1c1eea32e669"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1a100c6d595a7f316f1b6f01d20815d916e75ff98c27a01ae817439ea7726329"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6339d047dab2780cc6220f46306628e04d9750f02f983ddb37439ca47ced7149"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e4b749b9cc6ee664a3300bb3a273c1ca8068c46be705b6c31cf5d276f8628a94"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a38856a971c602f98472050165cea2cdc97709240373041b69030be15047691f"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f87f746ee241d30d6ed93969de31e5ffd09a2961a051e60ae6bddde9ec3583aa"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89f1b185a01fe560bc8ae5f619e924407efca2191b56ce749ec84982fc59a32a"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e1c8a2f4c69e08e89632defbfabec2feb8a8d99edc9f89ce33c4b9e36ab63037"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2f4ac36d8e2b4cc1aa71df3dd84ff8efbe3bfb97ac41242fbcfc053c67434f46"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a386ebe437176aab38c041de1260cd3ea459c6ce5263594399880bbc398225b2"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:ccd16eb18a849fd8dcb23e23380e2f0a354e8daa0c984b8a732d9cfaba3a776d"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:e6a5bf2cba5ae1bb80b154ed68a3cfa2fa00fde979a7f50d6598d3e17d9ac20c"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45de3f87179c1823e6d9e32156fb14c1927fcc9aba21433f088fdfb555b77c10"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win32.whl", hash = "sha256:1000fba1057b92a65daec275aec30586c3de2401ccdcd41f8a5c1e2c87078706"}, - {file = "charset_normalizer-3.2.0-cp38-cp38-win_amd64.whl", hash = "sha256:8b2c760cfc7042b27ebdb4a43a4453bd829a5742503599144d54a032c5dc7e9e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:855eafa5d5a2034b4621c74925d89c5efef61418570e5ef9b37717d9c796419c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:203f0c8871d5a7987be20c72442488a0b8cfd0f43b7973771640fc593f56321f"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e857a2232ba53ae940d3456f7533ce6ca98b81917d47adc3c7fd55dad8fab858"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e86d77b090dbddbe78867a0275cb4df08ea195e660f1f7f13435a4649e954e5"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c4fb39a81950ec280984b3a44f5bd12819953dc5fa3a7e6fa7a80db5ee853952"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2dee8e57f052ef5353cf608e0b4c871aee320dd1b87d351c28764fc0ca55f9f4"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8700f06d0ce6f128de3ccdbc1acaea1ee264d2caa9ca05daaf492fde7c2a7200"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1920d4ff15ce893210c1f0c0e9d19bfbecb7983c76b33f046c13a8ffbd570252"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:c1c76a1743432b4b60ab3358c937a3fe1341c828ae6194108a94c69028247f22"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:f7560358a6811e52e9c4d142d497f1a6e10103d3a6881f18d04dbce3729c0e2c"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c8063cf17b19661471ecbdb3df1c84f24ad2e389e326ccaf89e3fb2484d8dd7e"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:cd6dbe0238f7743d0efe563ab46294f54f9bc8f4b9bcf57c3c666cc5bc9d1299"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:1249cbbf3d3b04902ff081ffbb33ce3377fa6e4c7356f759f3cd076cc138d020"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win32.whl", hash = "sha256:6c409c0deba34f147f77efaa67b8e4bb83d2f11c8806405f76397ae5b8c0d1c9"}, - {file = "charset_normalizer-3.2.0-cp39-cp39-win_amd64.whl", hash = "sha256:7095f6fbfaa55defb6b733cfeb14efaae7a29f0b59d8cf213be4e7ca0b857b80"}, - {file = "charset_normalizer-3.2.0-py3-none-any.whl", hash = "sha256:8e098148dd37b4ce3baca71fb394c81dc5d9c7728c95df695d2dca218edf40e6"}, + {file = "charset-normalizer-3.3.2.tar.gz", hash = "sha256:f30c3cb33b24454a82faecaf01b19c18562b1e89558fb6c56de4d9118a032fd5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:25baf083bf6f6b341f4121c2f3c548875ee6f5339300e08be3f2b2ba1721cdd3"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:06435b539f889b1f6f4ac1758871aae42dc3a8c0e24ac9e60c2384973ad73027"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9063e24fdb1e498ab71cb7419e24622516c4a04476b17a2dab57e8baa30d6e03"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6897af51655e3691ff853668779c7bad41579facacf5fd7253b0133308cf000d"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1d3193f4a680c64b4b6a9115943538edb896edc190f0b222e73761716519268e"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cd70574b12bb8a4d2aaa0094515df2463cb429d8536cfb6c7ce983246983e5a6"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8465322196c8b4d7ab6d1e049e4c5cb460d0394da4a27d23cc242fbf0034b6b5"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a9a8e9031d613fd2009c182b69c7b2c1ef8239a0efb1df3f7c8da66d5dd3d537"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:beb58fe5cdb101e3a055192ac291b7a21e3b7ef4f67fa1d74e331a7f2124341c"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e06ed3eb3218bc64786f7db41917d4e686cc4856944f53d5bdf83a6884432e12"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:2e81c7b9c8979ce92ed306c249d46894776a909505d8f5a4ba55b14206e3222f"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:572c3763a264ba47b3cf708a44ce965d98555f618ca42c926a9c1616d8f34269"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:fd1abc0d89e30cc4e02e4064dc67fcc51bd941eb395c502aac3ec19fab46b519"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win32.whl", hash = "sha256:3d47fa203a7bd9c5b6cee4736ee84ca03b8ef23193c0d1ca99b5089f72645c73"}, + {file = "charset_normalizer-3.3.2-cp310-cp310-win_amd64.whl", hash = "sha256:10955842570876604d404661fbccbc9c7e684caf432c09c715ec38fbae45ae09"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:802fe99cca7457642125a8a88a084cef28ff0cf9407060f7b93dca5aa25480db"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:573f6eac48f4769d667c4442081b1794f52919e7edada77495aaed9236d13a96"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:549a3a73da901d5bc3ce8d24e0600d1fa85524c10287f6004fbab87672bf3e1e"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f27273b60488abe721a075bcca6d7f3964f9f6f067c8c4c605743023d7d3944f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ceae2f17a9c33cb48e3263960dc5fc8005351ee19db217e9b1bb15d28c02574"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:65f6f63034100ead094b8744b3b97965785388f308a64cf8d7c34f2f2e5be0c4"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:753f10e867343b4511128c6ed8c82f7bec3bd026875576dfd88483c5c73b2fd8"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:4a78b2b446bd7c934f5dcedc588903fb2f5eec172f3d29e52a9096a43722adfc"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e537484df0d8f426ce2afb2d0f8e1c3d0b114b83f8850e5f2fbea0e797bd82ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:eb6904c354526e758fda7167b33005998fb68c46fbc10e013ca97f21ca5c8887"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:deb6be0ac38ece9ba87dea880e438f25ca3eddfac8b002a2ec3d9183a454e8ae"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:4ab2fe47fae9e0f9dee8c04187ce5d09f48eabe611be8259444906793ab7cbce"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:80402cd6ee291dcb72644d6eac93785fe2c8b9cb30893c1af5b8fdd753b9d40f"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win32.whl", hash = "sha256:7cd13a2e3ddeed6913a65e66e94b51d80a041145a026c27e6bb76c31a853c6ab"}, + {file = "charset_normalizer-3.3.2-cp311-cp311-win_amd64.whl", hash = "sha256:663946639d296df6a2bb2aa51b60a2454ca1cb29835324c640dafb5ff2131a77"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0b2b64d2bb6d3fb9112bafa732def486049e63de9618b5843bcdd081d8144cd8"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:ddbb2551d7e0102e7252db79ba445cdab71b26640817ab1e3e3648dad515003b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:55086ee1064215781fff39a1af09518bc9255b50d6333f2e4c74ca09fac6a8f6"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8f4a014bc36d3c57402e2977dada34f9c12300af536839dc38c0beab8878f38a"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a10af20b82360ab00827f916a6058451b723b4e65030c5a18577c8b2de5b3389"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8d756e44e94489e49571086ef83b2bb8ce311e730092d2c34ca8f7d925cb20aa"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:90d558489962fd4918143277a773316e56c72da56ec7aa3dc3dbbe20fdfed15b"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6ac7ffc7ad6d040517be39eb591cac5ff87416c2537df6ba3cba3bae290c0fed"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7ed9e526742851e8d5cc9e6cf41427dfc6068d4f5a3bb03659444b4cabf6bc26"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:8bdb58ff7ba23002a4c5808d608e4e6c687175724f54a5dade5fa8c67b604e4d"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:6b3251890fff30ee142c44144871185dbe13b11bab478a88887a639655be1068"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:b4a23f61ce87adf89be746c8a8974fe1c823c891d8f86eb218bb957c924bb143"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:efcb3f6676480691518c177e3b465bcddf57cea040302f9f4e6e191af91174d4"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win32.whl", hash = "sha256:d965bba47ddeec8cd560687584e88cf699fd28f192ceb452d1d7ee807c5597b7"}, + {file = "charset_normalizer-3.3.2-cp312-cp312-win_amd64.whl", hash = "sha256:96b02a3dc4381e5494fad39be677abcb5e6634bf7b4fa83a6dd3112607547001"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:95f2a5796329323b8f0512e09dbb7a1860c46a39da62ecb2324f116fa8fdc85c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c002b4ffc0be611f0d9da932eb0f704fe2602a9a949d1f738e4c34c75b0863d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a981a536974bbc7a512cf44ed14938cf01030a99e9b3a06dd59578882f06f985"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3287761bc4ee9e33561a7e058c72ac0938c4f57fe49a09eae428fd88aafe7bb6"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:42cb296636fcc8b0644486d15c12376cb9fa75443e00fb25de0b8602e64c1714"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0a55554a2fa0d408816b3b5cedf0045f4b8e1a6065aec45849de2d6f3f8e9786"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:c083af607d2515612056a31f0a8d9e0fcb5876b7bfc0abad3ecd275bc4ebc2d5"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:87d1351268731db79e0f8e745d92493ee2841c974128ef629dc518b937d9194c"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bd8f7df7d12c2db9fab40bdd87a7c09b1530128315d047a086fa3ae3435cb3a8"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:c180f51afb394e165eafe4ac2936a14bee3eb10debc9d9e4db8958fe36afe711"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:8c622a5fe39a48f78944a87d4fb8a53ee07344641b0562c540d840748571b811"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win32.whl", hash = "sha256:db364eca23f876da6f9e16c9da0df51aa4f104a972735574842618b8c6d999d4"}, + {file = "charset_normalizer-3.3.2-cp37-cp37m-win_amd64.whl", hash = "sha256:86216b5cee4b06df986d214f664305142d9c76df9b6512be2738aa72a2048f99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:6463effa3186ea09411d50efc7d85360b38d5f09b870c48e4600f63af490e56a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6c4caeef8fa63d06bd437cd4bdcf3ffefe6738fb1b25951440d80dc7df8c03ac"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:37e55c8e51c236f95b033f6fb391d7d7970ba5fe7ff453dad675e88cf303377a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fb69256e180cb6c8a894fee62b3afebae785babc1ee98b81cdf68bbca1987f33"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ae5f4161f18c61806f411a13b0310bea87f987c7d2ecdbdaad0e94eb2e404238"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b2b0a0c0517616b6869869f8c581d4eb2dd83a4d79e0ebcb7d373ef9956aeb0a"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:45485e01ff4d3630ec0d9617310448a8702f70e9c01906b0d0118bdf9d124cf2"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:eb00ed941194665c332bf8e078baf037d6c35d7c4f3102ea2d4f16ca94a26dc8"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:2127566c664442652f024c837091890cb1942c30937add288223dc895793f898"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:a50aebfa173e157099939b17f18600f72f84eed3049e743b68ad15bd69b6bf99"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4d0d1650369165a14e14e1e47b372cfcb31d6ab44e6e33cb2d4e57265290044d"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:923c0c831b7cfcb071580d3f46c4baf50f174be571576556269530f4bbd79d04"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:06a81e93cd441c56a9b65d8e1d043daeb97a3d0856d177d5c90ba85acb3db087"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win32.whl", hash = "sha256:6ef1d82a3af9d3eecdba2321dc1b3c238245d890843e040e41e470ffa64c3e25"}, + {file = "charset_normalizer-3.3.2-cp38-cp38-win_amd64.whl", hash = "sha256:eb8821e09e916165e160797a6c17edda0679379a4be5c716c260e836e122f54b"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c235ebd9baae02f1b77bcea61bce332cb4331dc3617d254df3323aa01ab47bd4"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:5b4c145409bef602a690e7cfad0a15a55c13320ff7a3ad7ca59c13bb8ba4d45d"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:68d1f8a9e9e37c1223b656399be5d6b448dea850bed7d0f87a8311f1ff3dabb0"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:22afcb9f253dac0696b5a4be4a1c0f8762f8239e21b99680099abd9b2b1b2269"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:e27ad930a842b4c5eb8ac0016b0a54f5aebbe679340c26101df33424142c143c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1f79682fbe303db92bc2b1136016a38a42e835d932bab5b3b1bfcfbf0640e519"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b261ccdec7821281dade748d088bb6e9b69e6d15b30652b74cbbac25e280b796"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:122c7fa62b130ed55f8f285bfd56d5f4b4a5b503609d181f9ad85e55c89f4185"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:d0eccceffcb53201b5bfebb52600a5fb483a20b61da9dbc885f8b103cbe7598c"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f96df6923e21816da7e0ad3fd47dd8f94b2a5ce594e00677c0013018b813458"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:7f04c839ed0b6b98b1a7501a002144b76c18fb1c1850c8b98d458ac269e26ed2"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:34d1c8da1e78d2e001f363791c98a272bb734000fcef47a491c1e3b0505657a8"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ff8fa367d09b717b2a17a052544193ad76cd49979c805768879cb63d9ca50561"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win32.whl", hash = "sha256:aed38f6e4fb3f5d6bf81bfa990a07806be9d83cf7bacef998ab1a9bd660a581f"}, + {file = "charset_normalizer-3.3.2-cp39-cp39-win_amd64.whl", hash = "sha256:b01b88d45a6fcb69667cd6d2f7a9aeb4bf53760d7fc536bf679ec94fe9f3ff3d"}, + {file = "charset_normalizer-3.3.2-py3-none-any.whl", hash = "sha256:3e4d1f6587322d2788836a99c69062fbb091331ec940e02d12d179c1d53e25fc"}, ] [[package]] @@ -299,71 +347,83 @@ files = [ [[package]] name = "coverage" -version = "7.2.7" +version = "7.6.1" description = "Code coverage measurement for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "coverage-7.2.7-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d39b5b4f2a66ccae8b7263ac3c8170994b65266797fb96cbbfd3fb5b23921db8"}, - {file = "coverage-7.2.7-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:6d040ef7c9859bb11dfeb056ff5b3872436e3b5e401817d87a31e1750b9ae2fb"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ba90a9563ba44a72fda2e85302c3abc71c5589cea608ca16c22b9804262aaeb6"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e7d9405291c6928619403db1d10bd07888888ec1abcbd9748fdaa971d7d661b2"}, - {file = "coverage-7.2.7-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:31563e97dae5598556600466ad9beea39fb04e0229e61c12eaa206e0aa202063"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:ebba1cd308ef115925421d3e6a586e655ca5a77b5bf41e02eb0e4562a111f2d1"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:cb017fd1b2603ef59e374ba2063f593abe0fc45f2ad9abdde5b4d83bd922a353"}, - {file = "coverage-7.2.7-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d62a5c7dad11015c66fbb9d881bc4caa5b12f16292f857842d9d1871595f4495"}, - {file = "coverage-7.2.7-cp310-cp310-win32.whl", hash = "sha256:ee57190f24fba796e36bb6d3aa8a8783c643d8fa9760c89f7a98ab5455fbf818"}, - {file = "coverage-7.2.7-cp310-cp310-win_amd64.whl", hash = "sha256:f75f7168ab25dd93110c8a8117a22450c19976afbc44234cbf71481094c1b850"}, - {file = "coverage-7.2.7-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:06a9a2be0b5b576c3f18f1a241f0473575c4a26021b52b2a85263a00f034d51f"}, - {file = "coverage-7.2.7-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5baa06420f837184130752b7c5ea0808762083bf3487b5038d68b012e5937dbe"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdec9e8cbf13a5bf63290fc6013d216a4c7232efb51548594ca3631a7f13c3a3"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:52edc1a60c0d34afa421c9c37078817b2e67a392cab17d97283b64c5833f427f"}, - {file = "coverage-7.2.7-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:63426706118b7f5cf6bb6c895dc215d8a418d5952544042c8a2d9fe87fcf09cb"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:afb17f84d56068a7c29f5fa37bfd38d5aba69e3304af08ee94da8ed5b0865833"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:48c19d2159d433ccc99e729ceae7d5293fbffa0bdb94952d3579983d1c8c9d97"}, - {file = "coverage-7.2.7-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:0e1f928eaf5469c11e886fe0885ad2bf1ec606434e79842a879277895a50942a"}, - {file = "coverage-7.2.7-cp311-cp311-win32.whl", hash = "sha256:33d6d3ea29d5b3a1a632b3c4e4f4ecae24ef170b0b9ee493883f2df10039959a"}, - {file = "coverage-7.2.7-cp311-cp311-win_amd64.whl", hash = "sha256:5b7540161790b2f28143191f5f8ec02fb132660ff175b7747b95dcb77ac26562"}, - {file = "coverage-7.2.7-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f2f67fe12b22cd130d34d0ef79206061bfb5eda52feb6ce0dba0644e20a03cf4"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a342242fe22407f3c17f4b499276a02b01e80f861f1682ad1d95b04018e0c0d4"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:171717c7cb6b453aebac9a2ef603699da237f341b38eebfee9be75d27dc38e01"}, - {file = "coverage-7.2.7-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:49969a9f7ffa086d973d91cec8d2e31080436ef0fb4a359cae927e742abfaaa6"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:b46517c02ccd08092f4fa99f24c3b83d8f92f739b4657b0f146246a0ca6a831d"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:a3d33a6b3eae87ceaefa91ffdc130b5e8536182cd6dfdbfc1aa56b46ff8c86de"}, - {file = "coverage-7.2.7-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:976b9c42fb2a43ebf304fa7d4a310e5f16cc99992f33eced91ef6f908bd8f33d"}, - {file = "coverage-7.2.7-cp312-cp312-win32.whl", hash = "sha256:8de8bb0e5ad103888d65abef8bca41ab93721647590a3f740100cd65c3b00511"}, - {file = "coverage-7.2.7-cp312-cp312-win_amd64.whl", hash = "sha256:9e31cb64d7de6b6f09702bb27c02d1904b3aebfca610c12772452c4e6c21a0d3"}, - {file = "coverage-7.2.7-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:58c2ccc2f00ecb51253cbe5d8d7122a34590fac9646a960d1430d5b15321d95f"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d22656368f0e6189e24722214ed8d66b8022db19d182927b9a248a2a8a2f67eb"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a895fcc7b15c3fc72beb43cdcbdf0ddb7d2ebc959edac9cef390b0d14f39f8a9"}, - {file = "coverage-7.2.7-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e84606b74eb7de6ff581a7915e2dab7a28a0517fbe1c9239eb227e1354064dcd"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:0a5f9e1dbd7fbe30196578ca36f3fba75376fb99888c395c5880b355e2875f8a"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:419bfd2caae268623dd469eff96d510a920c90928b60f2073d79f8fe2bbc5959"}, - {file = "coverage-7.2.7-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:2aee274c46590717f38ae5e4650988d1af340fe06167546cc32fe2f58ed05b02"}, - {file = "coverage-7.2.7-cp37-cp37m-win32.whl", hash = "sha256:61b9a528fb348373c433e8966535074b802c7a5d7f23c4f421e6c6e2f1697a6f"}, - {file = "coverage-7.2.7-cp37-cp37m-win_amd64.whl", hash = "sha256:b1c546aca0ca4d028901d825015dc8e4d56aac4b541877690eb76490f1dc8ed0"}, - {file = "coverage-7.2.7-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:54b896376ab563bd38453cecb813c295cf347cf5906e8b41d340b0321a5433e5"}, - {file = "coverage-7.2.7-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:3d376df58cc111dc8e21e3b6e24606b5bb5dee6024f46a5abca99124b2229ef5"}, - {file = "coverage-7.2.7-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5e330fc79bd7207e46c7d7fd2bb4af2963f5f635703925543a70b99574b0fea9"}, - {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1e9d683426464e4a252bf70c3498756055016f99ddaec3774bf368e76bbe02b6"}, - {file = "coverage-7.2.7-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8d13c64ee2d33eccf7437961b6ea7ad8673e2be040b4f7fd4fd4d4d28d9ccb1e"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:b7aa5f8a41217360e600da646004f878250a0d6738bcdc11a0a39928d7dc2050"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:8fa03bce9bfbeeef9f3b160a8bed39a221d82308b4152b27d82d8daa7041fee5"}, - {file = "coverage-7.2.7-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:245167dd26180ab4c91d5e1496a30be4cd721a5cf2abf52974f965f10f11419f"}, - {file = "coverage-7.2.7-cp38-cp38-win32.whl", hash = "sha256:d2c2db7fd82e9b72937969bceac4d6ca89660db0a0967614ce2481e81a0b771e"}, - {file = "coverage-7.2.7-cp38-cp38-win_amd64.whl", hash = "sha256:2e07b54284e381531c87f785f613b833569c14ecacdcb85d56b25c4622c16c3c"}, - {file = "coverage-7.2.7-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:537891ae8ce59ef63d0123f7ac9e2ae0fc8b72c7ccbe5296fec45fd68967b6c9"}, - {file = "coverage-7.2.7-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:06fb182e69f33f6cd1d39a6c597294cff3143554b64b9825d1dc69d18cc2fff2"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:201e7389591af40950a6480bd9edfa8ed04346ff80002cec1a66cac4549c1ad7"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f6951407391b639504e3b3be51b7ba5f3528adbf1a8ac3302b687ecababf929e"}, - {file = "coverage-7.2.7-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6f48351d66575f535669306aa7d6d6f71bc43372473b54a832222803eb956fd1"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:b29019c76039dc3c0fd815c41392a044ce555d9bcdd38b0fb60fb4cd8e475ba9"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:81c13a1fc7468c40f13420732805a4c38a105d89848b7c10af65a90beff25250"}, - {file = "coverage-7.2.7-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:975d70ab7e3c80a3fe86001d8751f6778905ec723f5b110aed1e450da9d4b7f2"}, - {file = "coverage-7.2.7-cp39-cp39-win32.whl", hash = "sha256:7ee7d9d4822c8acc74a5e26c50604dff824710bc8de424904c0982e25c39c6cb"}, - {file = "coverage-7.2.7-cp39-cp39-win_amd64.whl", hash = "sha256:eb393e5ebc85245347950143969b241d08b52b88a3dc39479822e073a1a8eb27"}, - {file = "coverage-7.2.7-pp37.pp38.pp39-none-any.whl", hash = "sha256:b7b4c971f05e6ae490fef852c218b0e79d4e52f79ef0c8475566584a8fb3e01d"}, - {file = "coverage-7.2.7.tar.gz", hash = "sha256:924d94291ca674905fe9481f12294eb11f2d3d3fd1adb20314ba89e94f44ed59"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:b06079abebbc0e89e6163b8e8f0e16270124c154dc6e4a47b413dd538859af16"}, + {file = "coverage-7.6.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:cf4b19715bccd7ee27b6b120e7e9dd56037b9c0681dcc1adc9ba9db3d417fa36"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e61c0abb4c85b095a784ef23fdd4aede7a2628478e7baba7c5e3deba61070a02"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:fd21f6ae3f08b41004dfb433fa895d858f3f5979e7762d052b12aef444e29afc"}, + {file = "coverage-7.6.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8f59d57baca39b32db42b83b2a7ba6f47ad9c394ec2076b084c3f029b7afca23"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:a1ac0ae2b8bd743b88ed0502544847c3053d7171a3cff9228af618a068ed9c34"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e6a08c0be454c3b3beb105c0596ebdc2371fab6bb90c0c0297f4e58fd7e1012c"}, + {file = "coverage-7.6.1-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:f5796e664fe802da4f57a168c85359a8fbf3eab5e55cd4e4569fbacecc903959"}, + {file = "coverage-7.6.1-cp310-cp310-win32.whl", hash = "sha256:7bb65125fcbef8d989fa1dd0e8a060999497629ca5b0efbca209588a73356232"}, + {file = "coverage-7.6.1-cp310-cp310-win_amd64.whl", hash = "sha256:3115a95daa9bdba70aea750db7b96b37259a81a709223c8448fa97727d546fe0"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7dea0889685db8550f839fa202744652e87c60015029ce3f60e006f8c4462c93"}, + {file = "coverage-7.6.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ed37bd3c3b063412f7620464a9ac1314d33100329f39799255fb8d3027da50d3"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d85f5e9a5f8b73e2350097c3756ef7e785f55bd71205defa0bfdaf96c31616ff"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9bc572be474cafb617672c43fe989d6e48d3c83af02ce8de73fff1c6bb3c198d"}, + {file = "coverage-7.6.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c0420b573964c760df9e9e86d1a9a622d0d27f417e1a949a8a66dd7bcee7bc6"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:1f4aa8219db826ce6be7099d559f8ec311549bfc4046f7f9fe9b5cea5c581c56"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:fc5a77d0c516700ebad189b587de289a20a78324bc54baee03dd486f0855d234"}, + {file = "coverage-7.6.1-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:b48f312cca9621272ae49008c7f613337c53fadca647d6384cc129d2996d1133"}, + {file = "coverage-7.6.1-cp311-cp311-win32.whl", hash = "sha256:1125ca0e5fd475cbbba3bb67ae20bd2c23a98fac4e32412883f9bcbaa81c314c"}, + {file = "coverage-7.6.1-cp311-cp311-win_amd64.whl", hash = "sha256:8ae539519c4c040c5ffd0632784e21b2f03fc1340752af711f33e5be83a9d6c6"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:95cae0efeb032af8458fc27d191f85d1717b1d4e49f7cb226cf526ff28179778"}, + {file = "coverage-7.6.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:5621a9175cf9d0b0c84c2ef2b12e9f5f5071357c4d2ea6ca1cf01814f45d2391"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:260933720fdcd75340e7dbe9060655aff3af1f0c5d20f46b57f262ab6c86a5e8"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07e2ca0ad381b91350c0ed49d52699b625aab2b44b65e1b4e02fa9df0e92ad2d"}, + {file = "coverage-7.6.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c44fee9975f04b33331cb8eb272827111efc8930cfd582e0320613263ca849ca"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:877abb17e6339d96bf08e7a622d05095e72b71f8afd8a9fefc82cf30ed944163"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:3e0cadcf6733c09154b461f1ca72d5416635e5e4ec4e536192180d34ec160f8a"}, + {file = "coverage-7.6.1-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:c3c02d12f837d9683e5ab2f3d9844dc57655b92c74e286c262e0fc54213c216d"}, + {file = "coverage-7.6.1-cp312-cp312-win32.whl", hash = "sha256:e05882b70b87a18d937ca6768ff33cc3f72847cbc4de4491c8e73880766718e5"}, + {file = "coverage-7.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:b5d7b556859dd85f3a541db6a4e0167b86e7273e1cdc973e5b175166bb634fdb"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:a4acd025ecc06185ba2b801f2de85546e0b8ac787cf9d3b06e7e2a69f925b106"}, + {file = "coverage-7.6.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:a6d3adcf24b624a7b778533480e32434a39ad8fa30c315208f6d3e5542aeb6e9"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d0c212c49b6c10e6951362f7c6df3329f04c2b1c28499563d4035d964ab8e08c"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e81d7a3e58882450ec4186ca59a3f20a5d4440f25b1cff6f0902ad890e6748a"}, + {file = "coverage-7.6.1-cp313-cp313-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:78b260de9790fd81e69401c2dc8b17da47c8038176a79092a89cb2b7d945d060"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:a78d169acd38300060b28d600344a803628c3fd585c912cacc9ea8790fe96862"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:2c09f4ce52cb99dd7505cd0fc8e0e37c77b87f46bc9c1eb03fe3bc9991085388"}, + {file = "coverage-7.6.1-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:6878ef48d4227aace338d88c48738a4258213cd7b74fd9a3d4d7582bb1d8a155"}, + {file = "coverage-7.6.1-cp313-cp313-win32.whl", hash = "sha256:44df346d5215a8c0e360307d46ffaabe0f5d3502c8a1cefd700b34baf31d411a"}, + {file = "coverage-7.6.1-cp313-cp313-win_amd64.whl", hash = "sha256:8284cf8c0dd272a247bc154eb6c95548722dce90d098c17a883ed36e67cdb129"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:d3296782ca4eab572a1a4eca686d8bfb00226300dcefdf43faa25b5242ab8a3e"}, + {file = "coverage-7.6.1-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:502753043567491d3ff6d08629270127e0c31d4184c4c8d98f92c26f65019962"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a89ecca80709d4076b95f89f308544ec8f7b4727e8a547913a35f16717856cb"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a318d68e92e80af8b00fa99609796fdbcdfef3629c77c6283566c6f02c6d6704"}, + {file = "coverage-7.6.1-cp313-cp313t-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:13b0a73a0896988f053e4fbb7de6d93388e6dd292b0d87ee51d106f2c11b465b"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:4421712dbfc5562150f7554f13dde997a2e932a6b5f352edcce948a815efee6f"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_i686.whl", hash = "sha256:166811d20dfea725e2e4baa71fffd6c968a958577848d2131f39b60043400223"}, + {file = "coverage-7.6.1-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:225667980479a17db1048cb2bf8bfb39b8e5be8f164b8f6628b64f78a72cf9d3"}, + {file = "coverage-7.6.1-cp313-cp313t-win32.whl", hash = "sha256:170d444ab405852903b7d04ea9ae9b98f98ab6d7e63e1115e82620807519797f"}, + {file = "coverage-7.6.1-cp313-cp313t-win_amd64.whl", hash = "sha256:b9f222de8cded79c49bf184bdbc06630d4c58eec9459b939b4a690c82ed05657"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6db04803b6c7291985a761004e9060b2bca08da6d04f26a7f2294b8623a0c1a0"}, + {file = "coverage-7.6.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:f1adfc8ac319e1a348af294106bc6a8458a0f1633cc62a1446aebc30c5fa186a"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a95324a9de9650a729239daea117df21f4b9868ce32e63f8b650ebe6cef5595b"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b43c03669dc4618ec25270b06ecd3ee4fa94c7f9b3c14bae6571ca00ef98b0d3"}, + {file = "coverage-7.6.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8929543a7192c13d177b770008bc4e8119f2e1f881d563fc6b6305d2d0ebe9de"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:a09ece4a69cf399510c8ab25e0950d9cf2b42f7b3cb0374f95d2e2ff594478a6"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:9054a0754de38d9dbd01a46621636689124d666bad1936d76c0341f7d71bf569"}, + {file = "coverage-7.6.1-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:0dbde0f4aa9a16fa4d754356a8f2e36296ff4d83994b2c9d8398aa32f222f989"}, + {file = "coverage-7.6.1-cp38-cp38-win32.whl", hash = "sha256:da511e6ad4f7323ee5702e6633085fb76c2f893aaf8ce4c51a0ba4fc07580ea7"}, + {file = "coverage-7.6.1-cp38-cp38-win_amd64.whl", hash = "sha256:3f1156e3e8f2872197af3840d8ad307a9dd18e615dc64d9ee41696f287c57ad8"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:abd5fd0db5f4dc9289408aaf34908072f805ff7792632250dcb36dc591d24255"}, + {file = "coverage-7.6.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:547f45fa1a93154bd82050a7f3cddbc1a7a4dd2a9bf5cb7d06f4ae29fe94eaf8"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:645786266c8f18a931b65bfcefdbf6952dd0dea98feee39bd188607a9d307ed2"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9e0b2df163b8ed01d515807af24f63de04bebcecbd6c3bfeff88385789fdf75a"}, + {file = "coverage-7.6.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:609b06f178fe8e9f89ef676532760ec0b4deea15e9969bf754b37f7c40326dbc"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:702855feff378050ae4f741045e19a32d57d19f3e0676d589df0575008ea5004"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:2bdb062ea438f22d99cba0d7829c2ef0af1d768d1e4a4f528087224c90b132cb"}, + {file = "coverage-7.6.1-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:9c56863d44bd1c4fe2abb8a4d6f5371d197f1ac0ebdee542f07f35895fc07f36"}, + {file = "coverage-7.6.1-cp39-cp39-win32.whl", hash = "sha256:6e2cd258d7d927d09493c8df1ce9174ad01b381d4729a9d8d4e38670ca24774c"}, + {file = "coverage-7.6.1-cp39-cp39-win_amd64.whl", hash = "sha256:06a737c882bd26d0d6ee7269b20b12f14a8704807a01056c80bb881a4b2ce6ca"}, + {file = "coverage-7.6.1-pp38.pp39.pp310-none-any.whl", hash = "sha256:e9a6e0eb86070e8ccaedfbd9d38fec54864f3125ab95419970575b42af7541df"}, + {file = "coverage-7.6.1.tar.gz", hash = "sha256:953510dfb7b12ab69d20135a0662397f077c59b1e6379a768e97c59d852ee51d"}, ] [package.dependencies] @@ -383,134 +443,199 @@ files = [ {file = "decorator-5.1.1.tar.gz", hash = "sha256:637996211036b6385ef91435e4fae22989472f9d571faba8927ba8253acbc330"}, ] +[[package]] +name = "distro" +version = "1.9.0" +description = "Distro - an OS platform information API" +optional = false +python-versions = ">=3.6" +files = [ + {file = "distro-1.9.0-py3-none-any.whl", hash = "sha256:7bffd925d65168f85027d8da9af6bddab658135b840670a223589bc0c8ef02b2"}, + {file = "distro-1.9.0.tar.gz", hash = "sha256:2fa77c6fd8940f116ee1d6b94a2f90b13b5ea8d019b98bc8bafdcabcdd9bdbed"}, +] + +[[package]] +name = "exceptiongroup" +version = "1.2.2" +description = "Backport of PEP 654 (exception groups)" +optional = false +python-versions = ">=3.7" +files = [ + {file = "exceptiongroup-1.2.2-py3-none-any.whl", hash = "sha256:3111b9d131c238bec2f8f516e123e14ba243563fb135d3fe885990585aa7795b"}, + {file = "exceptiongroup-1.2.2.tar.gz", hash = "sha256:47c2edf7c6738fafb49fd34290706d1a1a2f4d1c6df275526b62cbb4aa5393cc"}, +] + +[package.extras] +test = ["pytest (>=6)"] + [[package]] name = "filelock" -version = "3.12.2" +version = "3.15.4" description = "A platform independent file lock." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "filelock-3.12.2-py3-none-any.whl", hash = "sha256:cbb791cdea2a72f23da6ac5b5269ab0a0d161e9ef0100e653b69049a7706d1ec"}, - {file = "filelock-3.12.2.tar.gz", hash = "sha256:002740518d8aa59a26b0c76e10fb8c6e15eae825d34b6fdf670333fd7b938d81"}, + {file = "filelock-3.15.4-py3-none-any.whl", hash = "sha256:6ca1fffae96225dab4c6eaf1c4f4f28cd2568d3ec2a44e15a08520504de468e7"}, + {file = "filelock-3.15.4.tar.gz", hash = "sha256:2207938cbc1844345cb01a5a95524dae30f0ce089eba5b00378295a17e3e90cb"}, ] [package.extras] -docs = ["furo (>=2023.5.20)", "sphinx (>=7.0.1)", "sphinx-autodoc-typehints (>=1.23,!=1.23.4)"] -testing = ["covdefaults (>=2.3)", "coverage (>=7.2.7)", "diff-cover (>=7.5)", "pytest (>=7.3.1)", "pytest-cov (>=4.1)", "pytest-mock (>=3.10)", "pytest-timeout (>=2.1)"] +docs = ["furo (>=2023.9.10)", "sphinx (>=7.2.6)", "sphinx-autodoc-typehints (>=1.25.2)"] +testing = ["covdefaults (>=2.3)", "coverage (>=7.3.2)", "diff-cover (>=8.0.1)", "pytest (>=7.4.3)", "pytest-asyncio (>=0.21)", "pytest-cov (>=4.1)", "pytest-mock (>=3.12)", "pytest-timeout (>=2.2)", "virtualenv (>=20.26.2)"] +typing = ["typing-extensions (>=4.8)"] [[package]] name = "frozenlist" -version = "1.3.3" +version = "1.4.1" description = "A list-like structure which implements collections.abc.MutableSequence" optional = false +python-versions = ">=3.8" +files = [ + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:f9aa1878d1083b276b0196f2dfbe00c9b7e752475ed3b682025ff20c1c1f51ac"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:29acab3f66f0f24674b7dc4736477bcd4bc3ad4b896f5f45379a67bce8b96868"}, + {file = "frozenlist-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:74fb4bee6880b529a0c6560885fce4dc95936920f9f20f53d99a213f7bf66776"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:590344787a90ae57d62511dd7c736ed56b428f04cd8c161fcc5e7232c130c69a"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:068b63f23b17df8569b7fdca5517edef76171cf3897eb68beb01341131fbd2ad"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c849d495bf5154cd8da18a9eb15db127d4dba2968d88831aff6f0331ea9bd4c"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:9750cc7fe1ae3b1611bb8cfc3f9ec11d532244235d75901fb6b8e42ce9229dfe"}, + {file = "frozenlist-1.4.1-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a9b2de4cf0cdd5bd2dee4c4f63a653c61d2408055ab77b151c1957f221cabf2a"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:0633c8d5337cb5c77acbccc6357ac49a1770b8c487e5b3505c57b949b4b82e98"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:27657df69e8801be6c3638054e202a135c7f299267f1a55ed3a598934f6c0d75"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:f9a3ea26252bd92f570600098783d1371354d89d5f6b7dfd87359d669f2109b5"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:4f57dab5fe3407b6c0c1cc907ac98e8a189f9e418f3b6e54d65a718aaafe3950"}, + {file = "frozenlist-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:e02a0e11cf6597299b9f3bbd3f93d79217cb90cfd1411aec33848b13f5c656cc"}, + {file = "frozenlist-1.4.1-cp310-cp310-win32.whl", hash = "sha256:a828c57f00f729620a442881cc60e57cfcec6842ba38e1b19fd3e47ac0ff8dc1"}, + {file = "frozenlist-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:f56e2333dda1fe0f909e7cc59f021eba0d2307bc6f012a1ccf2beca6ba362439"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a0cb6f11204443f27a1628b0e460f37fb30f624be6051d490fa7d7e26d4af3d0"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b46c8ae3a8f1f41a0d2ef350c0b6e65822d80772fe46b653ab6b6274f61d4a49"}, + {file = "frozenlist-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:fde5bd59ab5357e3853313127f4d3565fc7dad314a74d7b5d43c22c6a5ed2ced"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:722e1124aec435320ae01ee3ac7bec11a5d47f25d0ed6328f2273d287bc3abb0"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2471c201b70d58a0f0c1f91261542a03d9a5e088ed3dc6c160d614c01649c106"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c757a9dd70d72b076d6f68efdbb9bc943665ae954dad2801b874c8c69e185068"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f146e0911cb2f1da549fc58fc7bcd2b836a44b79ef871980d605ec392ff6b0d2"}, + {file = "frozenlist-1.4.1-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4f9c515e7914626b2a2e1e311794b4c35720a0be87af52b79ff8e1429fc25f19"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:c302220494f5c1ebeb0912ea782bcd5e2f8308037b3c7553fad0e48ebad6ad82"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:442acde1e068288a4ba7acfe05f5f343e19fac87bfc96d89eb886b0363e977ec"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:1b280e6507ea8a4fa0c0a7150b4e526a8d113989e28eaaef946cc77ffd7efc0a"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:fe1a06da377e3a1062ae5fe0926e12b84eceb8a50b350ddca72dc85015873f74"}, + {file = "frozenlist-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:db9e724bebd621d9beca794f2a4ff1d26eed5965b004a97f1f1685a173b869c2"}, + {file = "frozenlist-1.4.1-cp311-cp311-win32.whl", hash = "sha256:e774d53b1a477a67838a904131c4b0eef6b3d8a651f8b138b04f748fccfefe17"}, + {file = "frozenlist-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:fb3c2db03683b5767dedb5769b8a40ebb47d6f7f45b1b3e3b4b51ec8ad9d9825"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:1979bc0aeb89b33b588c51c54ab0161791149f2461ea7c7c946d95d5f93b56ae"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cc7b01b3754ea68a62bd77ce6020afaffb44a590c2289089289363472d13aedb"}, + {file = "frozenlist-1.4.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:c9c92be9fd329ac801cc420e08452b70e7aeab94ea4233a4804f0915c14eba9b"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c3894db91f5a489fc8fa6a9991820f368f0b3cbdb9cd8849547ccfab3392d86"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ba60bb19387e13597fb059f32cd4d59445d7b18b69a745b8f8e5db0346f33480"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8aefbba5f69d42246543407ed2461db31006b0f76c4e32dfd6f42215a2c41d09"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:780d3a35680ced9ce682fbcf4cb9c2bad3136eeff760ab33707b71db84664e3a"}, + {file = "frozenlist-1.4.1-cp312-cp312-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9acbb16f06fe7f52f441bb6f413ebae6c37baa6ef9edd49cdd567216da8600cd"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:23b701e65c7b36e4bf15546a89279bd4d8675faabc287d06bbcfac7d3c33e1e6"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:3e0153a805a98f5ada7e09826255ba99fb4f7524bb81bf6b47fb702666484ae1"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:dd9b1baec094d91bf36ec729445f7769d0d0cf6b64d04d86e45baf89e2b9059b"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:1a4471094e146b6790f61b98616ab8e44f72661879cc63fa1049d13ef711e71e"}, + {file = "frozenlist-1.4.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:5667ed53d68d91920defdf4035d1cdaa3c3121dc0b113255124bcfada1cfa1b8"}, + {file = "frozenlist-1.4.1-cp312-cp312-win32.whl", hash = "sha256:beee944ae828747fd7cb216a70f120767fc9f4f00bacae8543c14a6831673f89"}, + {file = "frozenlist-1.4.1-cp312-cp312-win_amd64.whl", hash = "sha256:64536573d0a2cb6e625cf309984e2d873979709f2cf22839bf2d61790b448ad5"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:20b51fa3f588ff2fe658663db52a41a4f7aa6c04f6201449c6c7c476bd255c0d"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:410478a0c562d1a5bcc2f7ea448359fcb050ed48b3c6f6f4f18c313a9bdb1826"}, + {file = "frozenlist-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:c6321c9efe29975232da3bd0af0ad216800a47e93d763ce64f291917a381b8eb"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:48f6a4533887e189dae092f1cf981f2e3885175f7a0f33c91fb5b7b682b6bab6"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6eb73fa5426ea69ee0e012fb59cdc76a15b1283d6e32e4f8dc4482ec67d1194d"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:fbeb989b5cc29e8daf7f976b421c220f1b8c731cbf22b9130d8815418ea45887"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:32453c1de775c889eb4e22f1197fe3bdfe457d16476ea407472b9442e6295f7a"}, + {file = "frozenlist-1.4.1-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:693945278a31f2086d9bf3df0fe8254bbeaef1fe71e1351c3bd730aa7d31c41b"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:1d0ce09d36d53bbbe566fe296965b23b961764c0bcf3ce2fa45f463745c04701"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:3a670dc61eb0d0eb7080890c13de3066790f9049b47b0de04007090807c776b0"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:dca69045298ce5c11fd539682cff879cc1e664c245d1c64da929813e54241d11"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a06339f38e9ed3a64e4c4e43aec7f59084033647f908e4259d279a52d3757d09"}, + {file = "frozenlist-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b7f2f9f912dca3934c1baec2e4585a674ef16fe00218d833856408c48d5beee7"}, + {file = "frozenlist-1.4.1-cp38-cp38-win32.whl", hash = "sha256:e7004be74cbb7d9f34553a5ce5fb08be14fb33bc86f332fb71cbe5216362a497"}, + {file = "frozenlist-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:5a7d70357e7cee13f470c7883a063aae5fe209a493c57d86eb7f5a6f910fae09"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:bfa4a17e17ce9abf47a74ae02f32d014c5e9404b6d9ac7f729e01562bbee601e"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:b7e3ed87d4138356775346e6845cccbe66cd9e207f3cd11d2f0b9fd13681359d"}, + {file = "frozenlist-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:c99169d4ff810155ca50b4da3b075cbde79752443117d89429595c2e8e37fed8"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:edb678da49d9f72c9f6c609fbe41a5dfb9a9282f9e6a2253d5a91e0fc382d7c0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:6db4667b187a6742b33afbbaf05a7bc551ffcf1ced0000a571aedbb4aa42fc7b"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55fdc093b5a3cb41d420884cdaf37a1e74c3c37a31f46e66286d9145d2063bd0"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:82e8211d69a4f4bc360ea22cd6555f8e61a1bd211d1d5d39d3d228b48c83a897"}, + {file = "frozenlist-1.4.1-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:89aa2c2eeb20957be2d950b85974b30a01a762f3308cd02bb15e1ad632e22dc7"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9d3e0c25a2350080e9319724dede4f31f43a6c9779be48021a7f4ebde8b2d742"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:7268252af60904bf52c26173cbadc3a071cece75f873705419c8681f24d3edea"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:0c250a29735d4f15321007fb02865f0e6b6a41a6b88f1f523ca1596ab5f50bd5"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:96ec70beabbd3b10e8bfe52616a13561e58fe84c0101dd031dc78f250d5128b9"}, + {file = "frozenlist-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:23b2d7679b73fe0e5a4560b672a39f98dfc6f60df63823b0a9970525325b95f6"}, + {file = "frozenlist-1.4.1-cp39-cp39-win32.whl", hash = "sha256:a7496bfe1da7fb1a4e1cc23bb67c58fab69311cc7d32b5a99c2007b4b2a0e932"}, + {file = "frozenlist-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e6a20a581f9ce92d389a8c7d7c3dd47c81fd5d6e655c8dddf341e14aa48659d0"}, + {file = "frozenlist-1.4.1-py3-none-any.whl", hash = "sha256:04ced3e6a46b4cfffe20f9ae482818e34eba9b5fb0ce4056e4cc9b6e212d09b7"}, + {file = "frozenlist-1.4.1.tar.gz", hash = "sha256:c037a86e8513059a2613aaba4d817bb90b9d9b6b69aace3ce9c877e8c8ed402b"}, +] + +[[package]] +name = "h11" +version = "0.14.0" +description = "A pure-Python, bring-your-own-I/O implementation of HTTP/1.1" +optional = false python-versions = ">=3.7" files = [ - {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:ff8bf625fe85e119553b5383ba0fb6aa3d0ec2ae980295aaefa552374926b3f4"}, - {file = "frozenlist-1.3.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:dfbac4c2dfcc082fcf8d942d1e49b6aa0766c19d3358bd86e2000bf0fa4a9cf0"}, - {file = "frozenlist-1.3.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b1c63e8d377d039ac769cd0926558bb7068a1f7abb0f003e3717ee003ad85530"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7fdfc24dcfce5b48109867c13b4cb15e4660e7bd7661741a391f821f23dfdca7"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2c926450857408e42f0bbc295e84395722ce74bae69a3b2aa2a65fe22cb14b99"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:1841e200fdafc3d51f974d9d377c079a0694a8f06de2e67b48150328d66d5483"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f470c92737afa7d4c3aacc001e335062d582053d4dbe73cda126f2d7031068dd"}, - {file = "frozenlist-1.3.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:783263a4eaad7c49983fe4b2e7b53fa9770c136c270d2d4bbb6d2192bf4d9caf"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:924620eef691990dfb56dc4709f280f40baee568c794b5c1885800c3ecc69816"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ae4dc05c465a08a866b7a1baf360747078b362e6a6dbeb0c57f234db0ef88ae0"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:bed331fe18f58d844d39ceb398b77d6ac0b010d571cba8267c2e7165806b00ce"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:02c9ac843e3390826a265e331105efeab489ffaf4dd86384595ee8ce6d35ae7f"}, - {file = "frozenlist-1.3.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:9545a33965d0d377b0bc823dcabf26980e77f1b6a7caa368a365a9497fb09420"}, - {file = "frozenlist-1.3.3-cp310-cp310-win32.whl", hash = "sha256:d5cd3ab21acbdb414bb6c31958d7b06b85eeb40f66463c264a9b343a4e238642"}, - {file = "frozenlist-1.3.3-cp310-cp310-win_amd64.whl", hash = "sha256:b756072364347cb6aa5b60f9bc18e94b2f79632de3b0190253ad770c5df17db1"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:b4395e2f8d83fbe0c627b2b696acce67868793d7d9750e90e39592b3626691b7"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:14143ae966a6229350021384870458e4777d1eae4c28d1a7aa47f24d030e6678"}, - {file = "frozenlist-1.3.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5d8860749e813a6f65bad8285a0520607c9500caa23fea6ee407e63debcdbef6"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23d16d9f477bb55b6154654e0e74557040575d9d19fe78a161bd33d7d76808e8"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb82dbba47a8318e75f679690190c10a5e1f447fbf9df41cbc4c3afd726d88cb"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9309869032abb23d196cb4e4db574232abe8b8be1339026f489eeb34a4acfd91"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a97b4fe50b5890d36300820abd305694cb865ddb7885049587a5678215782a6b"}, - {file = "frozenlist-1.3.3-cp311-cp311-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c188512b43542b1e91cadc3c6c915a82a5eb95929134faf7fd109f14f9892ce4"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:303e04d422e9b911a09ad499b0368dc551e8c3cd15293c99160c7f1f07b59a48"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:0771aed7f596c7d73444c847a1c16288937ef988dc04fb9f7be4b2aa91db609d"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:66080ec69883597e4d026f2f71a231a1ee9887835902dbe6b6467d5a89216cf6"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:41fe21dc74ad3a779c3d73a2786bdf622ea81234bdd4faf90b8b03cad0c2c0b4"}, - {file = "frozenlist-1.3.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f20380df709d91525e4bee04746ba612a4df0972c1b8f8e1e8af997e678c7b81"}, - {file = "frozenlist-1.3.3-cp311-cp311-win32.whl", hash = "sha256:f30f1928162e189091cf4d9da2eac617bfe78ef907a761614ff577ef4edfb3c8"}, - {file = "frozenlist-1.3.3-cp311-cp311-win_amd64.whl", hash = "sha256:a6394d7dadd3cfe3f4b3b186e54d5d8504d44f2d58dcc89d693698e8b7132b32"}, - {file = "frozenlist-1.3.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8df3de3a9ab8325f94f646609a66cbeeede263910c5c0de0101079ad541af332"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0693c609e9742c66ba4870bcee1ad5ff35462d5ffec18710b4ac89337ff16e27"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:cd4210baef299717db0a600d7a3cac81d46ef0e007f88c9335db79f8979c0d3d"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:394c9c242113bfb4b9aa36e2b80a05ffa163a30691c7b5a29eba82e937895d5e"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6327eb8e419f7d9c38f333cde41b9ae348bec26d840927332f17e887a8dcb70d"}, - {file = "frozenlist-1.3.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2e24900aa13212e75e5b366cb9065e78bbf3893d4baab6052d1aca10d46d944c"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:3843f84a6c465a36559161e6c59dce2f2ac10943040c2fd021cfb70d58c4ad56"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:84610c1502b2461255b4c9b7d5e9c48052601a8957cd0aea6ec7a7a1e1fb9420"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:c21b9aa40e08e4f63a2f92ff3748e6b6c84d717d033c7b3438dd3123ee18f70e"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:efce6ae830831ab6a22b9b4091d411698145cb9b8fc869e1397ccf4b4b6455cb"}, - {file = "frozenlist-1.3.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:40de71985e9042ca00b7953c4f41eabc3dc514a2d1ff534027f091bc74416401"}, - {file = "frozenlist-1.3.3-cp37-cp37m-win32.whl", hash = "sha256:180c00c66bde6146a860cbb81b54ee0df350d2daf13ca85b275123bbf85de18a"}, - {file = "frozenlist-1.3.3-cp37-cp37m-win_amd64.whl", hash = "sha256:9bbbcedd75acdfecf2159663b87f1bb5cfc80e7cd99f7ddd9d66eb98b14a8411"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:034a5c08d36649591be1cbb10e09da9f531034acfe29275fc5454a3b101ce41a"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:ba64dc2b3b7b158c6660d49cdb1d872d1d0bf4e42043ad8d5006099479a194e5"}, - {file = "frozenlist-1.3.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:47df36a9fe24054b950bbc2db630d508cca3aa27ed0566c0baf661225e52c18e"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:008a054b75d77c995ea26629ab3a0c0d7281341f2fa7e1e85fa6153ae29ae99c"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:841ea19b43d438a80b4de62ac6ab21cfe6827bb8a9dc62b896acc88eaf9cecba"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e235688f42b36be2b6b06fc37ac2126a73b75fb8d6bc66dd632aa35286238703"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ca713d4af15bae6e5d79b15c10c8522859a9a89d3b361a50b817c98c2fb402a2"}, - {file = "frozenlist-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9ac5995f2b408017b0be26d4a1d7c61bce106ff3d9e3324374d66b5964325448"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:a4ae8135b11652b08a8baf07631d3ebfe65a4c87909dbef5fa0cdde440444ee4"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4ea42116ceb6bb16dbb7d526e242cb6747b08b7710d9782aa3d6732bd8d27649"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:810860bb4bdce7557bc0febb84bbd88198b9dbc2022d8eebe5b3590b2ad6c842"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:ee78feb9d293c323b59a6f2dd441b63339a30edf35abcb51187d2fc26e696d13"}, - {file = "frozenlist-1.3.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:0af2e7c87d35b38732e810befb9d797a99279cbb85374d42ea61c1e9d23094b3"}, - {file = "frozenlist-1.3.3-cp38-cp38-win32.whl", hash = "sha256:899c5e1928eec13fd6f6d8dc51be23f0d09c5281e40d9cf4273d188d9feeaf9b"}, - {file = "frozenlist-1.3.3-cp38-cp38-win_amd64.whl", hash = "sha256:7f44e24fa70f6fbc74aeec3e971f60a14dde85da364aa87f15d1be94ae75aeef"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:2b07ae0c1edaa0a36339ec6cce700f51b14a3fc6545fdd32930d2c83917332cf"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ebb86518203e12e96af765ee89034a1dbb0c3c65052d1b0c19bbbd6af8a145e1"}, - {file = "frozenlist-1.3.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:5cf820485f1b4c91e0417ea0afd41ce5cf5965011b3c22c400f6d144296ccbc0"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5c11e43016b9024240212d2a65043b70ed8dfd3b52678a1271972702d990ac6d"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8fa3c6e3305aa1146b59a09b32b2e04074945ffcfb2f0931836d103a2c38f936"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:352bd4c8c72d508778cf05ab491f6ef36149f4d0cb3c56b1b4302852255d05d5"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:65a5e4d3aa679610ac6e3569e865425b23b372277f89b5ef06cf2cdaf1ebf22b"}, - {file = "frozenlist-1.3.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e2c1185858d7e10ff045c496bbf90ae752c28b365fef2c09cf0fa309291669"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f163d2fd041c630fed01bc48d28c3ed4a3b003c00acd396900e11ee5316b56bb"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:05cdb16d09a0832eedf770cb7bd1fe57d8cf4eaf5aced29c4e41e3f20b30a784"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:8bae29d60768bfa8fb92244b74502b18fae55a80eac13c88eb0b496d4268fd2d"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:eedab4c310c0299961ac285591acd53dc6723a1ebd90a57207c71f6e0c2153ab"}, - {file = "frozenlist-1.3.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:3bbdf44855ed8f0fbcd102ef05ec3012d6a4fd7c7562403f76ce6a52aeffb2b1"}, - {file = "frozenlist-1.3.3-cp39-cp39-win32.whl", hash = "sha256:efa568b885bca461f7c7b9e032655c0c143d305bf01c30caf6db2854a4532b38"}, - {file = "frozenlist-1.3.3-cp39-cp39-win_amd64.whl", hash = "sha256:cfe33efc9cb900a4c46f91a5ceba26d6df370ffddd9ca386eb1d4f0ad97b9ea9"}, - {file = "frozenlist-1.3.3.tar.gz", hash = "sha256:58bcc55721e8a90b88332d6cd441261ebb22342e238296bb330968952fbb3a6a"}, + {file = "h11-0.14.0-py3-none-any.whl", hash = "sha256:e3fe4ac4b851c468cc8363d500db52c2ead036020723024a109d37346efaa761"}, + {file = "h11-0.14.0.tar.gz", hash = "sha256:8f19fbbe99e72420ff35c00b27a34cb9937e902a8b810e2c88300c6f0a3b699d"}, ] [[package]] -name = "idna" -version = "3.4" -description = "Internationalized Domain Names in Applications (IDNA)" +name = "httpcore" +version = "1.0.5" +description = "A minimal low-level HTTP client." optional = false -python-versions = ">=3.5" +python-versions = ">=3.8" files = [ - {file = "idna-3.4-py3-none-any.whl", hash = "sha256:90b77e79eaa3eba6de819a0c442c0b4ceefc341a7a2ab77d7562bf49f425c5c2"}, - {file = "idna-3.4.tar.gz", hash = "sha256:814f528e8dead7d329833b91c5faa87d60bf71824cd12a7530b5526063d02cb4"}, + {file = "httpcore-1.0.5-py3-none-any.whl", hash = "sha256:421f18bac248b25d310f3cacd198d55b8e6125c107797b609ff9b7a6ba7991b5"}, + {file = "httpcore-1.0.5.tar.gz", hash = "sha256:34a38e2f9291467ee3b44e89dd52615370e152954ba21721378a87b2960f7a61"}, ] +[package.dependencies] +certifi = "*" +h11 = ">=0.13,<0.15" + +[package.extras] +asyncio = ["anyio (>=4.0,<5.0)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] +trio = ["trio (>=0.22.0,<0.26.0)"] + [[package]] -name = "importlib-metadata" -version = "6.7.0" -description = "Read metadata from Python packages" +name = "httpx" +version = "0.27.0" +description = "The next generation HTTP client." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "importlib_metadata-6.7.0-py3-none-any.whl", hash = "sha256:cb52082e659e97afc5dac71e79de97d8681de3aa07ff18578330904a9d18e5b5"}, - {file = "importlib_metadata-6.7.0.tar.gz", hash = "sha256:1aaf550d4f73e5d6783e7acb77aec43d49da8017410afae93822cc9cca98c4d4"}, + {file = "httpx-0.27.0-py3-none-any.whl", hash = "sha256:71d5465162c13681bff01ad59b2cc68dd838ea1f10e51574bac27103f00c91a5"}, + {file = "httpx-0.27.0.tar.gz", hash = "sha256:a0cb88a46f32dc874e04ee956e4c2764aba2aa228f650b06788ba6bda2962ab5"}, ] [package.dependencies] -typing-extensions = {version = ">=3.6.4", markers = "python_version < \"3.8\""} -zipp = ">=0.5" +anyio = "*" +certifi = "*" +httpcore = "==1.*" +idna = "*" +sniffio = "*" [package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -perf = ["ipython"] -testing = ["flufl.flake8", "importlib-resources (>=1.3)", "packaging", "pyfakefs", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-mypy (>=0.9.1)", "pytest-perf (>=0.9.2)", "pytest-ruff"] +brotli = ["brotli", "brotlicffi"] +cli = ["click (==8.*)", "pygments (==2.*)", "rich (>=10,<14)"] +http2 = ["h2 (>=3,<5)"] +socks = ["socksio (==1.*)"] + +[[package]] +name = "idna" +version = "3.7" +description = "Internationalized Domain Names in Applications (IDNA)" +optional = false +python-versions = ">=3.5" +files = [ + {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, + {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, +] [[package]] name = "iniconfig" @@ -523,134 +648,220 @@ files = [ {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, ] +[[package]] +name = "jiter" +version = "0.5.0" +description = "Fast iterable JSON parser." +optional = false +python-versions = ">=3.8" +files = [ + {file = "jiter-0.5.0-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:b599f4e89b3def9a94091e6ee52e1d7ad7bc33e238ebb9c4c63f211d74822c3f"}, + {file = "jiter-0.5.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:2a063f71c4b06225543dddadbe09d203dc0c95ba352d8b85f1221173480a71d5"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:acc0d5b8b3dd12e91dd184b87273f864b363dfabc90ef29a1092d269f18c7e28"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c22541f0b672f4d741382a97c65609332a783501551445ab2df137ada01e019e"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:63314832e302cc10d8dfbda0333a384bf4bcfce80d65fe99b0f3c0da8945a91a"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a25fbd8a5a58061e433d6fae6d5298777c0814a8bcefa1e5ecfff20c594bd749"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:503b2c27d87dfff5ab717a8200fbbcf4714516c9d85558048b1fc14d2de7d8dc"}, + {file = "jiter-0.5.0-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:6d1f3d27cce923713933a844872d213d244e09b53ec99b7a7fdf73d543529d6d"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:c95980207b3998f2c3b3098f357994d3fd7661121f30669ca7cb945f09510a87"}, + {file = "jiter-0.5.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:afa66939d834b0ce063f57d9895e8036ffc41c4bd90e4a99631e5f261d9b518e"}, + {file = "jiter-0.5.0-cp310-none-win32.whl", hash = "sha256:f16ca8f10e62f25fd81d5310e852df6649af17824146ca74647a018424ddeccf"}, + {file = "jiter-0.5.0-cp310-none-win_amd64.whl", hash = "sha256:b2950e4798e82dd9176935ef6a55cf6a448b5c71515a556da3f6b811a7844f1e"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d4c8e1ed0ef31ad29cae5ea16b9e41529eb50a7fba70600008e9f8de6376d553"}, + {file = "jiter-0.5.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:c6f16e21276074a12d8421692515b3fd6d2ea9c94fd0734c39a12960a20e85f3"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5280e68e7740c8c128d3ae5ab63335ce6d1fb6603d3b809637b11713487af9e6"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:583c57fc30cc1fec360e66323aadd7fc3edeec01289bfafc35d3b9dcb29495e4"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:26351cc14507bdf466b5f99aba3df3143a59da75799bf64a53a3ad3155ecded9"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4829df14d656b3fb87e50ae8b48253a8851c707da9f30d45aacab2aa2ba2d614"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a42a4bdcf7307b86cb863b2fb9bb55029b422d8f86276a50487982d99eed7c6e"}, + {file = "jiter-0.5.0-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:04d461ad0aebf696f8da13c99bc1b3e06f66ecf6cfd56254cc402f6385231c06"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:e6375923c5f19888c9226582a124b77b622f8fd0018b843c45eeb19d9701c403"}, + {file = "jiter-0.5.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2cec323a853c24fd0472517113768c92ae0be8f8c384ef4441d3632da8baa646"}, + {file = "jiter-0.5.0-cp311-none-win32.whl", hash = "sha256:aa1db0967130b5cab63dfe4d6ff547c88b2a394c3410db64744d491df7f069bb"}, + {file = "jiter-0.5.0-cp311-none-win_amd64.whl", hash = "sha256:aa9d2b85b2ed7dc7697597dcfaac66e63c1b3028652f751c81c65a9f220899ae"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:9f664e7351604f91dcdd557603c57fc0d551bc65cc0a732fdacbf73ad335049a"}, + {file = "jiter-0.5.0-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:044f2f1148b5248ad2c8c3afb43430dccf676c5a5834d2f5089a4e6c5bbd64df"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:702e3520384c88b6e270c55c772d4bd6d7b150608dcc94dea87ceba1b6391248"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:528d742dcde73fad9d63e8242c036ab4a84389a56e04efd854062b660f559544"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8cf80e5fe6ab582c82f0c3331df27a7e1565e2dcf06265afd5173d809cdbf9ba"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:44dfc9ddfb9b51a5626568ef4e55ada462b7328996294fe4d36de02fce42721f"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c451f7922992751a936b96c5f5b9bb9312243d9b754c34b33d0cb72c84669f4e"}, + {file = "jiter-0.5.0-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:308fce789a2f093dca1ff91ac391f11a9f99c35369117ad5a5c6c4903e1b3e3a"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:7f5ad4a7c6b0d90776fdefa294f662e8a86871e601309643de30bf94bb93a64e"}, + {file = "jiter-0.5.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:ea189db75f8eca08807d02ae27929e890c7d47599ce3d0a6a5d41f2419ecf338"}, + {file = "jiter-0.5.0-cp312-none-win32.whl", hash = "sha256:e3bbe3910c724b877846186c25fe3c802e105a2c1fc2b57d6688b9f8772026e4"}, + {file = "jiter-0.5.0-cp312-none-win_amd64.whl", hash = "sha256:a586832f70c3f1481732919215f36d41c59ca080fa27a65cf23d9490e75b2ef5"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:f04bc2fc50dc77be9d10f73fcc4e39346402ffe21726ff41028f36e179b587e6"}, + {file = "jiter-0.5.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:6f433a4169ad22fcb550b11179bb2b4fd405de9b982601914ef448390b2954f3"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ad4a6398c85d3a20067e6c69890ca01f68659da94d74c800298581724e426c7e"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:6baa88334e7af3f4d7a5c66c3a63808e5efbc3698a1c57626541ddd22f8e4fbf"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:1ece0a115c05efca597c6d938f88c9357c843f8c245dbbb53361a1c01afd7148"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:335942557162ad372cc367ffaf93217117401bf930483b4b3ebdb1223dbddfa7"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:649b0ee97a6e6da174bffcb3c8c051a5935d7d4f2f52ea1583b5b3e7822fbf14"}, + {file = "jiter-0.5.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:f4be354c5de82157886ca7f5925dbda369b77344b4b4adf2723079715f823989"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5206144578831a6de278a38896864ded4ed96af66e1e63ec5dd7f4a1fce38a3a"}, + {file = "jiter-0.5.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8120c60f8121ac3d6f072b97ef0e71770cc72b3c23084c72c4189428b1b1d3b6"}, + {file = "jiter-0.5.0-cp38-none-win32.whl", hash = "sha256:6f1223f88b6d76b519cb033a4d3687ca157c272ec5d6015c322fc5b3074d8a5e"}, + {file = "jiter-0.5.0-cp38-none-win_amd64.whl", hash = "sha256:c59614b225d9f434ea8fc0d0bec51ef5fa8c83679afedc0433905994fb36d631"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:0af3838cfb7e6afee3f00dc66fa24695199e20ba87df26e942820345b0afc566"}, + {file = "jiter-0.5.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:550b11d669600dbc342364fd4adbe987f14d0bbedaf06feb1b983383dcc4b961"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:489875bf1a0ffb3cb38a727b01e6673f0f2e395b2aad3c9387f94187cb214bbf"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:b250ca2594f5599ca82ba7e68785a669b352156260c5362ea1b4e04a0f3e2389"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8ea18e01f785c6667ca15407cd6dabbe029d77474d53595a189bdc813347218e"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:462a52be85b53cd9bffd94e2d788a09984274fe6cebb893d6287e1c296d50653"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:92cc68b48d50fa472c79c93965e19bd48f40f207cb557a8346daa020d6ba973b"}, + {file = "jiter-0.5.0-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:1c834133e59a8521bc87ebcad773608c6fa6ab5c7a022df24a45030826cf10bc"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:ab3a71ff31cf2d45cb216dc37af522d335211f3a972d2fe14ea99073de6cb104"}, + {file = "jiter-0.5.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:cccd3af9c48ac500c95e1bcbc498020c87e1781ff0345dd371462d67b76643eb"}, + {file = "jiter-0.5.0-cp39-none-win32.whl", hash = "sha256:368084d8d5c4fc40ff7c3cc513c4f73e02c85f6009217922d0823a48ee7adf61"}, + {file = "jiter-0.5.0-cp39-none-win_amd64.whl", hash = "sha256:ce03f7b4129eb72f1687fa11300fbf677b02990618428934662406d2a76742a1"}, + {file = "jiter-0.5.0.tar.gz", hash = "sha256:1d916ba875bcab5c5f7d927df998c4cb694d27dceddf3392e58beaf10563368a"}, +] + [[package]] name = "multidict" -version = "6.0.4" +version = "6.0.5" description = "multidict implementation" optional = false python-versions = ">=3.7" files = [ - {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"}, - {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:eeb6dcc05e911516ae3d1f207d4b0520d07f54484c49dfc294d6e7d63b734171"}, - {file = "multidict-6.0.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:d6d635d5209b82a3492508cf5b365f3446afb65ae7ebd755e70e18f287b0adf7"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c048099e4c9e9d615545e2001d3d8a4380bd403e1a0578734e0d31703d1b0c0b"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ea20853c6dbbb53ed34cb4d080382169b6f4554d394015f1bef35e881bf83547"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:16d232d4e5396c2efbbf4f6d4df89bfa905eb0d4dc5b3549d872ab898451f569"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:36c63aaa167f6c6b04ef2c85704e93af16c11d20de1d133e39de6a0e84582a93"}, - {file = "multidict-6.0.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:64bdf1086b6043bf519869678f5f2757f473dee970d7abf6da91ec00acb9cb98"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:43644e38f42e3af682690876cff722d301ac585c5b9e1eacc013b7a3f7b696a0"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:7582a1d1030e15422262de9f58711774e02fa80df0d1578995c76214f6954988"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:ddff9c4e225a63a5afab9dd15590432c22e8057e1a9a13d28ed128ecf047bbdc"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:ee2a1ece51b9b9e7752e742cfb661d2a29e7bcdba2d27e66e28a99f1890e4fa0"}, - {file = "multidict-6.0.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a2e4369eb3d47d2034032a26c7a80fcb21a2cb22e1173d761a162f11e562caa5"}, - {file = "multidict-6.0.4-cp310-cp310-win32.whl", hash = "sha256:574b7eae1ab267e5f8285f0fe881f17efe4b98c39a40858247720935b893bba8"}, - {file = "multidict-6.0.4-cp310-cp310-win_amd64.whl", hash = "sha256:4dcbb0906e38440fa3e325df2359ac6cb043df8e58c965bb45f4e406ecb162cc"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:0dfad7a5a1e39c53ed00d2dd0c2e36aed4650936dc18fd9a1826a5ae1cad6f03"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:64da238a09d6039e3bd39bb3aee9c21a5e34f28bfa5aa22518581f910ff94af3"}, - {file = "multidict-6.0.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:ff959bee35038c4624250473988b24f846cbeb2c6639de3602c073f10410ceba"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:01a3a55bd90018c9c080fbb0b9f4891db37d148a0a18722b42f94694f8b6d4c9"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c5cb09abb18c1ea940fb99360ea0396f34d46566f157122c92dfa069d3e0e982"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:666daae833559deb2d609afa4490b85830ab0dfca811a98b70a205621a6109fe"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11bdf3f5e1518b24530b8241529d2050014c884cf18b6fc69c0c2b30ca248710"}, - {file = "multidict-6.0.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d18748f2d30f94f498e852c67d61261c643b349b9d2a581131725595c45ec6c"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:458f37be2d9e4c95e2d8866a851663cbc76e865b78395090786f6cd9b3bbf4f4"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:b1a2eeedcead3a41694130495593a559a668f382eee0727352b9a41e1c45759a"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:7d6ae9d593ef8641544d6263c7fa6408cc90370c8cb2bbb65f8d43e5b0351d9c"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:5979b5632c3e3534e42ca6ff856bb24b2e3071b37861c2c727ce220d80eee9ed"}, - {file = "multidict-6.0.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:dcfe792765fab89c365123c81046ad4103fcabbc4f56d1c1997e6715e8015461"}, - {file = "multidict-6.0.4-cp311-cp311-win32.whl", hash = "sha256:3601a3cece3819534b11d4efc1eb76047488fddd0c85a3948099d5da4d504636"}, - {file = "multidict-6.0.4-cp311-cp311-win_amd64.whl", hash = "sha256:81a4f0b34bd92df3da93315c6a59034df95866014ac08535fc819f043bfd51f0"}, - {file = "multidict-6.0.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:67040058f37a2a51ed8ea8f6b0e6ee5bd78ca67f169ce6122f3e2ec80dfe9b78"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:853888594621e6604c978ce2a0444a1e6e70c8d253ab65ba11657659dcc9100f"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:39ff62e7d0f26c248b15e364517a72932a611a9b75f35b45be078d81bdb86603"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:af048912e045a2dc732847d33821a9d84ba553f5c5f028adbd364dd4765092ac"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b1e8b901e607795ec06c9e42530788c45ac21ef3aaa11dbd0c69de543bfb79a9"}, - {file = "multidict-6.0.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:62501642008a8b9871ddfccbf83e4222cf8ac0d5aeedf73da36153ef2ec222d2"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:99b76c052e9f1bc0721f7541e5e8c05db3941eb9ebe7b8553c625ef88d6eefde"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:509eac6cf09c794aa27bcacfd4d62c885cce62bef7b2c3e8b2e49d365b5003fe"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:21a12c4eb6ddc9952c415f24eef97e3e55ba3af61f67c7bc388dcdec1404a067"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:5cad9430ab3e2e4fa4a2ef4450f548768400a2ac635841bc2a56a2052cdbeb87"}, - {file = "multidict-6.0.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab55edc2e84460694295f401215f4a58597f8f7c9466faec545093045476327d"}, - {file = "multidict-6.0.4-cp37-cp37m-win32.whl", hash = "sha256:5a4dcf02b908c3b8b17a45fb0f15b695bf117a67b76b7ad18b73cf8e92608775"}, - {file = "multidict-6.0.4-cp37-cp37m-win_amd64.whl", hash = "sha256:6ed5f161328b7df384d71b07317f4d8656434e34591f20552c7bcef27b0ab88e"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5fc1b16f586f049820c5c5b17bb4ee7583092fa0d1c4e28b5239181ff9532e0c"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:1502e24330eb681bdaa3eb70d6358e818e8e8f908a22a1851dfd4e15bc2f8161"}, - {file = "multidict-6.0.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b692f419760c0e65d060959df05f2a531945af31fda0c8a3b3195d4efd06de11"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:45e1ecb0379bfaab5eef059f50115b54571acfbe422a14f668fc8c27ba410e7e"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:ddd3915998d93fbcd2566ddf9cf62cdb35c9e093075f862935573d265cf8f65d"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:59d43b61c59d82f2effb39a93c48b845efe23a3852d201ed2d24ba830d0b4cf2"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cc8e1d0c705233c5dd0c5e6460fbad7827d5d36f310a0fadfd45cc3029762258"}, - {file = "multidict-6.0.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d6aa0418fcc838522256761b3415822626f866758ee0bc6632c9486b179d0b52"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:6748717bb10339c4760c1e63da040f5f29f5ed6e59d76daee30305894069a660"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:4d1a3d7ef5e96b1c9e92f973e43aa5e5b96c659c9bc3124acbbd81b0b9c8a951"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:4372381634485bec7e46718edc71528024fcdc6f835baefe517b34a33c731d60"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:fc35cb4676846ef752816d5be2193a1e8367b4c1397b74a565a9d0389c433a1d"}, - {file = "multidict-6.0.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:4b9d9e4e2b37daddb5c23ea33a3417901fa7c7b3dee2d855f63ee67a0b21e5b1"}, - {file = "multidict-6.0.4-cp38-cp38-win32.whl", hash = "sha256:e41b7e2b59679edfa309e8db64fdf22399eec4b0b24694e1b2104fb789207779"}, - {file = "multidict-6.0.4-cp38-cp38-win_amd64.whl", hash = "sha256:d6c254ba6e45d8e72739281ebc46ea5eb5f101234f3ce171f0e9f5cc86991480"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:16ab77bbeb596e14212e7bab8429f24c1579234a3a462105cda4a66904998664"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bc779e9e6f7fda81b3f9aa58e3a6091d49ad528b11ed19f6621408806204ad35"}, - {file = "multidict-6.0.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:4ceef517eca3e03c1cceb22030a3e39cb399ac86bff4e426d4fc6ae49052cc60"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:281af09f488903fde97923c7744bb001a9b23b039a909460d0f14edc7bf59706"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:52f2dffc8acaba9a2f27174c41c9e57f60b907bb9f096b36b1a1f3be71c6284d"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b41156839806aecb3641f3208c0dafd3ac7775b9c4c422d82ee2a45c34ba81ca"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5e3fc56f88cc98ef8139255cf8cd63eb2c586531e43310ff859d6bb3a6b51f1"}, - {file = "multidict-6.0.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8316a77808c501004802f9beebde51c9f857054a0c871bd6da8280e718444449"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:f70b98cd94886b49d91170ef23ec5c0e8ebb6f242d734ed7ed677b24d50c82cf"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:bf6774e60d67a9efe02b3616fee22441d86fab4c6d335f9d2051d19d90a40063"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:e69924bfcdda39b722ef4d9aa762b2dd38e4632b3641b1d9a57ca9cd18f2f83a"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:6b181d8c23da913d4ff585afd1155a0e1194c0b50c54fcfe286f70cdaf2b7176"}, - {file = "multidict-6.0.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:52509b5be062d9eafc8170e53026fbc54cf3b32759a23d07fd935fb04fc22d95"}, - {file = "multidict-6.0.4-cp39-cp39-win32.whl", hash = "sha256:27c523fbfbdfd19c6867af7346332b62b586eed663887392cff78d614f9ec313"}, - {file = "multidict-6.0.4-cp39-cp39-win_amd64.whl", hash = "sha256:33029f5734336aa0d4c0384525da0387ef89148dc7191aae00ca5fb23d7aafc2"}, - {file = "multidict-6.0.4.tar.gz", hash = "sha256:3666906492efb76453c0e7b97f2cf459b0682e7402c0489a95484965dbc1da49"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b644ae063c10e7f324ab1ab6b548bdf6f8b47f3ec234fef1093bc2735e5f9"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:896ebdcf62683551312c30e20614305f53125750803b614e9e6ce74a96232604"}, + {file = "multidict-6.0.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:411bf8515f3be9813d06004cac41ccf7d1cd46dfe233705933dd163b60e37600"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1d147090048129ce3c453f0292e7697d333db95e52616b3793922945804a433c"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:215ed703caf15f578dca76ee6f6b21b7603791ae090fbf1ef9d865571039ade5"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c6390cf87ff6234643428991b7359b5f59cc15155695deb4eda5c777d2b880f"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:21fd81c4ebdb4f214161be351eb5bcf385426bf023041da2fd9e60681f3cebae"}, + {file = "multidict-6.0.5-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3cc2ad10255f903656017363cd59436f2111443a76f996584d1077e43ee51182"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:6939c95381e003f54cd4c5516740faba40cf5ad3eeff460c3ad1d3e0ea2549bf"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:220dd781e3f7af2c2c1053da9fa96d9cf3072ca58f057f4c5adaaa1cab8fc442"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:766c8f7511df26d9f11cd3a8be623e59cca73d44643abab3f8c8c07620524e4a"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:fe5d7785250541f7f5019ab9cba2c71169dc7d74d0f45253f8313f436458a4ef"}, + {file = "multidict-6.0.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:c1c1496e73051918fcd4f58ff2e0f2f3066d1c76a0c6aeffd9b45d53243702cc"}, + {file = "multidict-6.0.5-cp310-cp310-win32.whl", hash = "sha256:7afcdd1fc07befad18ec4523a782cde4e93e0a2bf71239894b8d61ee578c1319"}, + {file = "multidict-6.0.5-cp310-cp310-win_amd64.whl", hash = "sha256:99f60d34c048c5c2fabc766108c103612344c46e35d4ed9ae0673d33c8fb26e8"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:f285e862d2f153a70586579c15c44656f888806ed0e5b56b64489afe4a2dbfba"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:53689bb4e102200a4fafa9de9c7c3c212ab40a7ab2c8e474491914d2305f187e"}, + {file = "multidict-6.0.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:612d1156111ae11d14afaf3a0669ebf6c170dbb735e510a7438ffe2369a847fd"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:7be7047bd08accdb7487737631d25735c9a04327911de89ff1b26b81745bd4e3"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:de170c7b4fe6859beb8926e84f7d7d6c693dfe8e27372ce3b76f01c46e489fcf"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:04bde7a7b3de05732a4eb39c94574db1ec99abb56162d6c520ad26f83267de29"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:85f67aed7bb647f93e7520633d8f51d3cbc6ab96957c71272b286b2f30dc70ed"}, + {file = "multidict-6.0.5-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:425bf820055005bfc8aa9a0b99ccb52cc2f4070153e34b701acc98d201693733"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:d3eb1ceec286eba8220c26f3b0096cf189aea7057b6e7b7a2e60ed36b373b77f"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:7901c05ead4b3fb75113fb1dd33eb1253c6d3ee37ce93305acd9d38e0b5f21a4"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e0e79d91e71b9867c73323a3444724d496c037e578a0e1755ae159ba14f4f3d1"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:29bfeb0dff5cb5fdab2023a7a9947b3b4af63e9c47cae2a10ad58394b517fddc"}, + {file = "multidict-6.0.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e030047e85cbcedbfc073f71836d62dd5dadfbe7531cae27789ff66bc551bd5e"}, + {file = "multidict-6.0.5-cp311-cp311-win32.whl", hash = "sha256:2f4848aa3baa109e6ab81fe2006c77ed4d3cd1e0ac2c1fbddb7b1277c168788c"}, + {file = "multidict-6.0.5-cp311-cp311-win_amd64.whl", hash = "sha256:2faa5ae9376faba05f630d7e5e6be05be22913782b927b19d12b8145968a85ea"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:51d035609b86722963404f711db441cf7134f1889107fb171a970c9701f92e1e"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:cbebcd5bcaf1eaf302617c114aa67569dd3f090dd0ce8ba9e35e9985b41ac35b"}, + {file = "multidict-6.0.5-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:2ffc42c922dbfddb4a4c3b438eb056828719f07608af27d163191cb3e3aa6cc5"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ceb3b7e6a0135e092de86110c5a74e46bda4bd4fbfeeb3a3bcec79c0f861e450"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:79660376075cfd4b2c80f295528aa6beb2058fd289f4c9252f986751a4cd0496"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e4428b29611e989719874670fd152b6625500ad6c686d464e99f5aaeeaca175a"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d84a5c3a5f7ce6db1f999fb9438f686bc2e09d38143f2d93d8406ed2dd6b9226"}, + {file = "multidict-6.0.5-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:76c0de87358b192de7ea9649beb392f107dcad9ad27276324c24c91774ca5271"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:79a6d2ba910adb2cbafc95dad936f8b9386e77c84c35bc0add315b856d7c3abb"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:92d16a3e275e38293623ebf639c471d3e03bb20b8ebb845237e0d3664914caef"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:fb616be3538599e797a2017cccca78e354c767165e8858ab5116813146041a24"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:14c2976aa9038c2629efa2c148022ed5eb4cb939e15ec7aace7ca932f48f9ba6"}, + {file = "multidict-6.0.5-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:435a0984199d81ca178b9ae2c26ec3d49692d20ee29bc4c11a2a8d4514c67eda"}, + {file = "multidict-6.0.5-cp312-cp312-win32.whl", hash = "sha256:9fe7b0653ba3d9d65cbe7698cca585bf0f8c83dbbcc710db9c90f478e175f2d5"}, + {file = "multidict-6.0.5-cp312-cp312-win_amd64.whl", hash = "sha256:01265f5e40f5a17f8241d52656ed27192be03bfa8764d88e8220141d1e4b3556"}, + {file = "multidict-6.0.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:19fe01cea168585ba0f678cad6f58133db2aa14eccaf22f88e4a6dccadfad8b3"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6bf7a982604375a8d49b6cc1b781c1747f243d91b81035a9b43a2126c04766f5"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:107c0cdefe028703fb5dafe640a409cb146d44a6ae201e55b35a4af8e95457dd"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:403c0911cd5d5791605808b942c88a8155c2592e05332d2bf78f18697a5fa15e"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aeaf541ddbad8311a87dd695ed9642401131ea39ad7bc8cf3ef3967fd093b626"}, + {file = "multidict-6.0.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e4972624066095e52b569e02b5ca97dbd7a7ddd4294bf4e7247d52635630dd83"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d946b0a9eb8aaa590df1fe082cee553ceab173e6cb5b03239716338629c50c7a"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:b55358304d7a73d7bdf5de62494aaf70bd33015831ffd98bc498b433dfe5b10c"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:a3145cb08d8625b2d3fee1b2d596a8766352979c9bffe5d7833e0503d0f0b5e5"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d65f25da8e248202bd47445cec78e0025c0fe7582b23ec69c3b27a640dd7a8e3"}, + {file = "multidict-6.0.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:c9bf56195c6bbd293340ea82eafd0071cb3d450c703d2c93afb89f93b8386ccc"}, + {file = "multidict-6.0.5-cp37-cp37m-win32.whl", hash = "sha256:69db76c09796b313331bb7048229e3bee7928eb62bab5e071e9f7fcc4879caee"}, + {file = "multidict-6.0.5-cp37-cp37m-win_amd64.whl", hash = "sha256:fce28b3c8a81b6b36dfac9feb1de115bab619b3c13905b419ec71d03a3fc1423"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:76f067f5121dcecf0d63a67f29080b26c43c71a98b10c701b0677e4a065fbd54"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b82cc8ace10ab5bd93235dfaab2021c70637005e1ac787031f4d1da63d493c1d"}, + {file = "multidict-6.0.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:5cb241881eefd96b46f89b1a056187ea8e9ba14ab88ba632e68d7a2ecb7aadf7"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e8e94e6912639a02ce173341ff62cc1201232ab86b8a8fcc05572741a5dc7d93"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:09a892e4a9fb47331da06948690ae38eaa2426de97b4ccbfafbdcbe5c8f37ff8"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:55205d03e8a598cfc688c71ca8ea5f66447164efff8869517f175ea632c7cb7b"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:37b15024f864916b4951adb95d3a80c9431299080341ab9544ed148091b53f50"}, + {file = "multidict-6.0.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f2a1dee728b52b33eebff5072817176c172050d44d67befd681609b4746e1c2e"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:edd08e6f2f1a390bf137080507e44ccc086353c8e98c657e666c017718561b89"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:60d698e8179a42ec85172d12f50b1668254628425a6bd611aba022257cac1386"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:3d25f19500588cbc47dc19081d78131c32637c25804df8414463ec908631e453"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:4cc0ef8b962ac7a5e62b9e826bd0cd5040e7d401bc45a6835910ed699037a461"}, + {file = "multidict-6.0.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:eca2e9d0cc5a889850e9bbd68e98314ada174ff6ccd1129500103df7a94a7a44"}, + {file = "multidict-6.0.5-cp38-cp38-win32.whl", hash = "sha256:4a6a4f196f08c58c59e0b8ef8ec441d12aee4125a7d4f4fef000ccb22f8d7241"}, + {file = "multidict-6.0.5-cp38-cp38-win_amd64.whl", hash = "sha256:0275e35209c27a3f7951e1ce7aaf93ce0d163b28948444bec61dd7badc6d3f8c"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:e7be68734bd8c9a513f2b0cfd508802d6609da068f40dc57d4e3494cefc92929"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:1d9ea7a7e779d7a3561aade7d596649fbecfa5c08a7674b11b423783217933f9"}, + {file = "multidict-6.0.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ea1456df2a27c73ce51120fa2f519f1bea2f4a03a917f4a43c8707cf4cbbae1a"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cf590b134eb70629e350691ecca88eac3e3b8b3c86992042fb82e3cb1830d5e1"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5c0631926c4f58e9a5ccce555ad7747d9a9f8b10619621f22f9635f069f6233e"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dce1c6912ab9ff5f179eaf6efe7365c1f425ed690b03341911bf4939ef2f3046"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c0868d64af83169e4d4152ec612637a543f7a336e4a307b119e98042e852ad9c"}, + {file = "multidict-6.0.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:141b43360bfd3bdd75f15ed811850763555a251e38b2405967f8e25fb43f7d40"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:7df704ca8cf4a073334e0427ae2345323613e4df18cc224f647f251e5e75a527"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:6214c5a5571802c33f80e6c84713b2c79e024995b9c5897f794b43e714daeec9"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:cd6c8fca38178e12c00418de737aef1261576bd1b6e8c6134d3e729a4e858b38"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:e02021f87a5b6932fa6ce916ca004c4d441509d33bbdbeca70d05dff5e9d2479"}, + {file = "multidict-6.0.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:ebd8d160f91a764652d3e51ce0d2956b38efe37c9231cd82cfc0bed2e40b581c"}, + {file = "multidict-6.0.5-cp39-cp39-win32.whl", hash = "sha256:04da1bb8c8dbadf2a18a452639771951c662c5ad03aefe4884775454be322c9b"}, + {file = "multidict-6.0.5-cp39-cp39-win_amd64.whl", hash = "sha256:d6f6d4f185481c9669b9447bf9d9cf3b95a0e9df9d169bbc17e363b7d5487755"}, + {file = "multidict-6.0.5-py3-none-any.whl", hash = "sha256:0d63c74e3d7ab26de115c49bffc92cc77ed23395303d496eae515d4204a625e7"}, + {file = "multidict-6.0.5.tar.gz", hash = "sha256:f7e301075edaf50500f0b341543c41194d8df3ae5caf4702f2095f3ca73dd8da"}, ] [[package]] name = "mypy" -version = "1.4.1" +version = "1.11.1" description = "Optional static typing for Python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "mypy-1.4.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:566e72b0cd6598503e48ea610e0052d1b8168e60a46e0bfd34b3acf2d57f96a8"}, - {file = "mypy-1.4.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:ca637024ca67ab24a7fd6f65d280572c3794665eaf5edcc7e90a866544076878"}, - {file = "mypy-1.4.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0dde1d180cd84f0624c5dcaaa89c89775550a675aff96b5848de78fb11adabcd"}, - {file = "mypy-1.4.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:8c4d8e89aa7de683e2056a581ce63c46a0c41e31bd2b6d34144e2c80f5ea53dc"}, - {file = "mypy-1.4.1-cp310-cp310-win_amd64.whl", hash = "sha256:bfdca17c36ae01a21274a3c387a63aa1aafe72bff976522886869ef131b937f1"}, - {file = "mypy-1.4.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7549fbf655e5825d787bbc9ecf6028731973f78088fbca3a1f4145c39ef09462"}, - {file = "mypy-1.4.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:98324ec3ecf12296e6422939e54763faedbfcc502ea4a4c38502082711867258"}, - {file = "mypy-1.4.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:141dedfdbfe8a04142881ff30ce6e6653c9685b354876b12e4fe6c78598b45e2"}, - {file = "mypy-1.4.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:8207b7105829eca6f3d774f64a904190bb2231de91b8b186d21ffd98005f14a7"}, - {file = "mypy-1.4.1-cp311-cp311-win_amd64.whl", hash = "sha256:16f0db5b641ba159eff72cff08edc3875f2b62b2fa2bc24f68c1e7a4e8232d01"}, - {file = "mypy-1.4.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:470c969bb3f9a9efcedbadcd19a74ffb34a25f8e6b0e02dae7c0e71f8372f97b"}, - {file = "mypy-1.4.1-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e5952d2d18b79f7dc25e62e014fe5a23eb1a3d2bc66318df8988a01b1a037c5b"}, - {file = "mypy-1.4.1-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:190b6bab0302cec4e9e6767d3eb66085aef2a1cc98fe04936d8a42ed2ba77bb7"}, - {file = "mypy-1.4.1-cp37-cp37m-win_amd64.whl", hash = "sha256:9d40652cc4fe33871ad3338581dca3297ff5f2213d0df345bcfbde5162abf0c9"}, - {file = "mypy-1.4.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:01fd2e9f85622d981fd9063bfaef1aed6e336eaacca00892cd2d82801ab7c042"}, - {file = "mypy-1.4.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:2460a58faeea905aeb1b9b36f5065f2dc9a9c6e4c992a6499a2360c6c74ceca3"}, - {file = "mypy-1.4.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a2746d69a8196698146a3dbe29104f9eb6a2a4d8a27878d92169a6c0b74435b6"}, - {file = "mypy-1.4.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:ae704dcfaa180ff7c4cfbad23e74321a2b774f92ca77fd94ce1049175a21c97f"}, - {file = "mypy-1.4.1-cp38-cp38-win_amd64.whl", hash = "sha256:43d24f6437925ce50139a310a64b2ab048cb2d3694c84c71c3f2a1626d8101dc"}, - {file = "mypy-1.4.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c482e1246726616088532b5e964e39765b6d1520791348e6c9dc3af25b233828"}, - {file = "mypy-1.4.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:43b592511672017f5b1a483527fd2684347fdffc041c9ef53428c8dc530f79a3"}, - {file = "mypy-1.4.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:34a9239d5b3502c17f07fd7c0b2ae6b7dd7d7f6af35fbb5072c6208e76295816"}, - {file = "mypy-1.4.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5703097c4936bbb9e9bce41478c8d08edd2865e177dc4c52be759f81ee4dd26c"}, - {file = "mypy-1.4.1-cp39-cp39-win_amd64.whl", hash = "sha256:e02d700ec8d9b1859790c0475df4e4092c7bf3272a4fd2c9f33d87fac4427b8f"}, - {file = "mypy-1.4.1-py3-none-any.whl", hash = "sha256:45d32cec14e7b97af848bddd97d85ea4f0db4d5a149ed9676caa4eb2f7402bb4"}, - {file = "mypy-1.4.1.tar.gz", hash = "sha256:9bbcd9ab8ea1f2e1c8031c21445b511442cc45c89951e49bbf852cbb70755b1b"}, + {file = "mypy-1.11.1-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a32fc80b63de4b5b3e65f4be82b4cfa362a46702672aa6a0f443b4689af7008c"}, + {file = "mypy-1.11.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c1952f5ea8a5a959b05ed5f16452fddadbaae48b5d39235ab4c3fc444d5fd411"}, + {file = "mypy-1.11.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:e1e30dc3bfa4e157e53c1d17a0dad20f89dc433393e7702b813c10e200843b03"}, + {file = "mypy-1.11.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:2c63350af88f43a66d3dfeeeb8d77af34a4f07d760b9eb3a8697f0386c7590b4"}, + {file = "mypy-1.11.1-cp310-cp310-win_amd64.whl", hash = "sha256:a831671bad47186603872a3abc19634f3011d7f83b083762c942442d51c58d58"}, + {file = "mypy-1.11.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:7b6343d338390bb946d449677726edf60102a1c96079b4f002dedff375953fc5"}, + {file = "mypy-1.11.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:e4fe9f4e5e521b458d8feb52547f4bade7ef8c93238dfb5bbc790d9ff2d770ca"}, + {file = "mypy-1.11.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:886c9dbecc87b9516eff294541bf7f3655722bf22bb898ee06985cd7269898de"}, + {file = "mypy-1.11.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:fca4a60e1dd9fd0193ae0067eaeeb962f2d79e0d9f0f66223a0682f26ffcc809"}, + {file = "mypy-1.11.1-cp311-cp311-win_amd64.whl", hash = "sha256:0bd53faf56de9643336aeea1c925012837432b5faf1701ccca7fde70166ccf72"}, + {file = "mypy-1.11.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:f39918a50f74dc5969807dcfaecafa804fa7f90c9d60506835036cc1bc891dc8"}, + {file = "mypy-1.11.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:0bc71d1fb27a428139dd78621953effe0d208aed9857cb08d002280b0422003a"}, + {file = "mypy-1.11.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:b868d3bcff720dd7217c383474008ddabaf048fad8d78ed948bb4b624870a417"}, + {file = "mypy-1.11.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:a707ec1527ffcdd1c784d0924bf5cb15cd7f22683b919668a04d2b9c34549d2e"}, + {file = "mypy-1.11.1-cp312-cp312-win_amd64.whl", hash = "sha256:64f4a90e3ea07f590c5bcf9029035cf0efeae5ba8be511a8caada1a4893f5525"}, + {file = "mypy-1.11.1-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:749fd3213916f1751fff995fccf20c6195cae941dc968f3aaadf9bb4e430e5a2"}, + {file = "mypy-1.11.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:b639dce63a0b19085213ec5fdd8cffd1d81988f47a2dec7100e93564f3e8fb3b"}, + {file = "mypy-1.11.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4c956b49c5d865394d62941b109728c5c596a415e9c5b2be663dd26a1ff07bc0"}, + {file = "mypy-1.11.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:45df906e8b6804ef4b666af29a87ad9f5921aad091c79cc38e12198e220beabd"}, + {file = "mypy-1.11.1-cp38-cp38-win_amd64.whl", hash = "sha256:d44be7551689d9d47b7abc27c71257adfdb53f03880841a5db15ddb22dc63edb"}, + {file = "mypy-1.11.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:2684d3f693073ab89d76da8e3921883019ea8a3ec20fa5d8ecca6a2db4c54bbe"}, + {file = "mypy-1.11.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:79c07eb282cb457473add5052b63925e5cc97dfab9812ee65a7c7ab5e3cb551c"}, + {file = "mypy-1.11.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:11965c2f571ded6239977b14deebd3f4c3abd9a92398712d6da3a772974fad69"}, + {file = "mypy-1.11.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a2b43895a0f8154df6519706d9bca8280cda52d3d9d1514b2d9c3e26792a0b74"}, + {file = "mypy-1.11.1-cp39-cp39-win_amd64.whl", hash = "sha256:1a81cf05975fd61aec5ae16501a091cfb9f605dc3e3c878c0da32f250b74760b"}, + {file = "mypy-1.11.1-py3-none-any.whl", hash = "sha256:0624bdb940255d2dd24e829d99a13cfeb72e4e9031f9492148f410ed30bcab54"}, + {file = "mypy-1.11.1.tar.gz", hash = "sha256:f404a0b069709f18bbdb702eb3dcfe51910602995de00bd39cea3050b5772d08"}, ] [package.dependencies] mypy-extensions = ">=1.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} -typed-ast = {version = ">=1.4.0,<2", markers = "python_version < \"3.8\""} -typing-extensions = ">=4.1.0" +typing-extensions = ">=4.6.0" [package.extras] dmypy = ["psutil (>=4.0)"] install-types = ["pip"] -python2 = ["typed-ast (>=1.4.0,<2)"] +mypyc = ["setuptools (>=50)"] reports = ["lxml"] [[package]] @@ -664,31 +875,52 @@ files = [ {file = "mypy_extensions-1.0.0.tar.gz", hash = "sha256:75dbf8955dc00442a438fc4d0666508a9a97b6bd41aa2f0ffe9d2f2725af0782"}, ] +[[package]] +name = "openai" +version = "1.41.1" +description = "The official Python library for the openai API" +optional = false +python-versions = ">=3.7.1" +files = [ + {file = "openai-1.41.1-py3-none-any.whl", hash = "sha256:56fb04105263f79559aff3ceea2e1dd16f8c5385e8238cb66cf0e6888fa8bfcf"}, + {file = "openai-1.41.1.tar.gz", hash = "sha256:e38e376efd91e0d4db071e2a6517b6b4cac1c2a6fd63efdc5ec6be10c5967c1b"}, +] + +[package.dependencies] +anyio = ">=3.5.0,<5" +distro = ">=1.7.0,<2" +httpx = ">=0.23.0,<1" +jiter = ">=0.4.0,<1" +pydantic = ">=1.9.0,<3" +sniffio = "*" +tqdm = ">4" +typing-extensions = ">=4.11,<5" + +[package.extras] +datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"] + [[package]] name = "packaging" -version = "23.1" +version = "24.1" description = "Core utilities for Python packages" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "packaging-23.1-py3-none-any.whl", hash = "sha256:994793af429502c4ea2ebf6bf664629d07c1a9fe974af92966e4b8d2df7edc61"}, - {file = "packaging-23.1.tar.gz", hash = "sha256:a392980d2b6cffa644431898be54b0045151319d1e7ec34f0cfed48767dd334f"}, + {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, + {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, ] [[package]] name = "pluggy" -version = "1.2.0" +version = "1.5.0" description = "plugin and hook calling mechanisms for python" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pluggy-1.2.0-py3-none-any.whl", hash = "sha256:c2fd55a7d7a3863cba1a013e4e2414658b1d07b6bc57b3919e0c63c9abb99849"}, - {file = "pluggy-1.2.0.tar.gz", hash = "sha256:d12f0c4b579b15f5e054301bb226ee85eeeba08ffec228092f8defbaa3a4c4b3"}, + {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, + {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, ] -[package.dependencies] -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} - [package.extras] dev = ["pre-commit", "tox"] testing = ["pytest", "pytest-benchmark"] @@ -706,55 +938,126 @@ files = [ [[package]] name = "pydantic" -version = "1.10.11" -description = "Data validation and settings management using python type hints" +version = "2.8.2" +description = "Data validation using Python type hints" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "pydantic-1.10.11-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:ff44c5e89315b15ff1f7fdaf9853770b810936d6b01a7bcecaa227d2f8fe444f"}, - {file = "pydantic-1.10.11-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:a6c098d4ab5e2d5b3984d3cb2527e2d6099d3de85630c8934efcfdc348a9760e"}, - {file = "pydantic-1.10.11-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:16928fdc9cb273c6af00d9d5045434c39afba5f42325fb990add2c241402d151"}, - {file = "pydantic-1.10.11-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:0588788a9a85f3e5e9ebca14211a496409cb3deca5b6971ff37c556d581854e7"}, - {file = "pydantic-1.10.11-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e9baf78b31da2dc3d3f346ef18e58ec5f12f5aaa17ac517e2ffd026a92a87588"}, - {file = "pydantic-1.10.11-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:373c0840f5c2b5b1ccadd9286782852b901055998136287828731868027a724f"}, - {file = "pydantic-1.10.11-cp310-cp310-win_amd64.whl", hash = "sha256:c3339a46bbe6013ef7bdd2844679bfe500347ac5742cd4019a88312aa58a9847"}, - {file = "pydantic-1.10.11-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:08a6c32e1c3809fbc49debb96bf833164f3438b3696abf0fbeceb417d123e6eb"}, - {file = "pydantic-1.10.11-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a451ccab49971af043ec4e0d207cbc8cbe53dbf148ef9f19599024076fe9c25b"}, - {file = "pydantic-1.10.11-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5b02d24f7b2b365fed586ed73582c20f353a4c50e4be9ba2c57ab96f8091ddae"}, - {file = "pydantic-1.10.11-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3f34739a89260dfa420aa3cbd069fbcc794b25bbe5c0a214f8fb29e363484b66"}, - {file = "pydantic-1.10.11-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:e297897eb4bebde985f72a46a7552a7556a3dd11e7f76acda0c1093e3dbcf216"}, - {file = "pydantic-1.10.11-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:d185819a7a059550ecb85d5134e7d40f2565f3dd94cfd870132c5f91a89cf58c"}, - {file = "pydantic-1.10.11-cp311-cp311-win_amd64.whl", hash = "sha256:4400015f15c9b464c9db2d5d951b6a780102cfa5870f2c036d37c23b56f7fc1b"}, - {file = "pydantic-1.10.11-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2417de68290434461a266271fc57274a138510dca19982336639484c73a07af6"}, - {file = "pydantic-1.10.11-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:331c031ba1554b974c98679bd0780d89670d6fd6f53f5d70b10bdc9addee1713"}, - {file = "pydantic-1.10.11-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8268a735a14c308923e8958363e3a3404f6834bb98c11f5ab43251a4e410170c"}, - {file = "pydantic-1.10.11-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:44e51ba599c3ef227e168424e220cd3e544288c57829520dc90ea9cb190c3248"}, - {file = "pydantic-1.10.11-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d7781f1d13b19700b7949c5a639c764a077cbbdd4322ed505b449d3ca8edcb36"}, - {file = "pydantic-1.10.11-cp37-cp37m-win_amd64.whl", hash = "sha256:7522a7666157aa22b812ce14c827574ddccc94f361237ca6ea8bb0d5c38f1629"}, - {file = "pydantic-1.10.11-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:bc64eab9b19cd794a380179ac0e6752335e9555d214cfcb755820333c0784cb3"}, - {file = "pydantic-1.10.11-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:8dc77064471780262b6a68fe67e013298d130414d5aaf9b562c33987dbd2cf4f"}, - {file = "pydantic-1.10.11-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fe429898f2c9dd209bd0632a606bddc06f8bce081bbd03d1c775a45886e2c1cb"}, - {file = "pydantic-1.10.11-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:192c608ad002a748e4a0bed2ddbcd98f9b56df50a7c24d9a931a8c5dd053bd3d"}, - {file = "pydantic-1.10.11-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:ef55392ec4bb5721f4ded1096241e4b7151ba6d50a50a80a2526c854f42e6a2f"}, - {file = "pydantic-1.10.11-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:41e0bb6efe86281623abbeeb0be64eab740c865388ee934cd3e6a358784aca6e"}, - {file = "pydantic-1.10.11-cp38-cp38-win_amd64.whl", hash = "sha256:265a60da42f9f27e0b1014eab8acd3e53bd0bad5c5b4884e98a55f8f596b2c19"}, - {file = "pydantic-1.10.11-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:469adf96c8e2c2bbfa655fc7735a2a82f4c543d9fee97bd113a7fb509bf5e622"}, - {file = "pydantic-1.10.11-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e6cbfbd010b14c8a905a7b10f9fe090068d1744d46f9e0c021db28daeb8b6de1"}, - {file = "pydantic-1.10.11-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:abade85268cc92dff86d6effcd917893130f0ff516f3d637f50dadc22ae93999"}, - {file = "pydantic-1.10.11-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e9738b0f2e6c70f44ee0de53f2089d6002b10c33264abee07bdb5c7f03038303"}, - {file = "pydantic-1.10.11-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:787cf23e5a0cde753f2eabac1b2e73ae3844eb873fd1f5bdbff3048d8dbb7604"}, - {file = "pydantic-1.10.11-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:174899023337b9fc685ac8adaa7b047050616136ccd30e9070627c1aaab53a13"}, - {file = "pydantic-1.10.11-cp39-cp39-win_amd64.whl", hash = "sha256:1954f8778489a04b245a1e7b8b22a9d3ea8ef49337285693cf6959e4b757535e"}, - {file = "pydantic-1.10.11-py3-none-any.whl", hash = "sha256:008c5e266c8aada206d0627a011504e14268a62091450210eda7c07fabe6963e"}, - {file = "pydantic-1.10.11.tar.gz", hash = "sha256:f66d479cf7eb331372c470614be6511eae96f1f120344c25f3f9bb59fb1b5528"}, + {file = "pydantic-2.8.2-py3-none-any.whl", hash = "sha256:73ee9fddd406dc318b885c7a2eab8a6472b68b8fb5ba8150949fc3db939f23c8"}, + {file = "pydantic-2.8.2.tar.gz", hash = "sha256:6f62c13d067b0755ad1c21a34bdd06c0c12625a22b0fc09c6b149816604f7c2a"}, ] [package.dependencies] -typing-extensions = ">=4.2.0" +annotated-types = ">=0.4.0" +pydantic-core = "2.20.1" +typing-extensions = [ + {version = ">=4.12.2", markers = "python_version >= \"3.13\""}, + {version = ">=4.6.1", markers = "python_version < \"3.13\""}, +] [package.extras] -dotenv = ["python-dotenv (>=0.10.4)"] -email = ["email-validator (>=1.0.3)"] +email = ["email-validator (>=2.0.0)"] + +[[package]] +name = "pydantic-core" +version = "2.20.1" +description = "Core functionality for Pydantic validation and serialization" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pydantic_core-2.20.1-cp310-cp310-macosx_10_12_x86_64.whl", hash = "sha256:3acae97ffd19bf091c72df4d726d552c473f3576409b2a7ca36b2f535ffff4a3"}, + {file = "pydantic_core-2.20.1-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:41f4c96227a67a013e7de5ff8f20fb496ce573893b7f4f2707d065907bffdbd6"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f239eb799a2081495ea659d8d4a43a8f42cd1fe9ff2e7e436295c38a10c286a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:53e431da3fc53360db73eedf6f7124d1076e1b4ee4276b36fb25514544ceb4a3"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:f1f62b2413c3a0e846c3b838b2ecd6c7a19ec6793b2a522745b0869e37ab5bc1"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5d41e6daee2813ecceea8eda38062d69e280b39df793f5a942fa515b8ed67953"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3d482efec8b7dc6bfaedc0f166b2ce349df0011f5d2f1f25537ced4cfc34fd98"}, + {file = "pydantic_core-2.20.1-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:e93e1a4b4b33daed65d781a57a522ff153dcf748dee70b40c7258c5861e1768a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:e7c4ea22b6739b162c9ecaaa41d718dfad48a244909fe7ef4b54c0b530effc5a"}, + {file = "pydantic_core-2.20.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:4f2790949cf385d985a31984907fecb3896999329103df4e4983a4a41e13e840"}, + {file = "pydantic_core-2.20.1-cp310-none-win32.whl", hash = "sha256:5e999ba8dd90e93d57410c5e67ebb67ffcaadcea0ad973240fdfd3a135506250"}, + {file = "pydantic_core-2.20.1-cp310-none-win_amd64.whl", hash = "sha256:512ecfbefef6dac7bc5eaaf46177b2de58cdf7acac8793fe033b24ece0b9566c"}, + {file = "pydantic_core-2.20.1-cp311-cp311-macosx_10_12_x86_64.whl", hash = "sha256:d2a8fa9d6d6f891f3deec72f5cc668e6f66b188ab14bb1ab52422fe8e644f312"}, + {file = "pydantic_core-2.20.1-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:175873691124f3d0da55aeea1d90660a6ea7a3cfea137c38afa0a5ffabe37b88"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37eee5b638f0e0dcd18d21f59b679686bbd18917b87db0193ae36f9c23c355fc"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:25e9185e2d06c16ee438ed39bf62935ec436474a6ac4f9358524220f1b236e43"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:150906b40ff188a3260cbee25380e7494ee85048584998c1e66df0c7a11c17a6"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ad4aeb3e9a97286573c03df758fc7627aecdd02f1da04516a86dc159bf70121"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d3f3ed29cd9f978c604708511a1f9c2fdcb6c38b9aae36a51905b8811ee5cbf1"}, + {file = "pydantic_core-2.20.1-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:b0dae11d8f5ded51699c74d9548dcc5938e0804cc8298ec0aa0da95c21fff57b"}, + {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:faa6b09ee09433b87992fb5a2859efd1c264ddc37280d2dd5db502126d0e7f27"}, + {file = "pydantic_core-2.20.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:9dc1b507c12eb0481d071f3c1808f0529ad41dc415d0ca11f7ebfc666e66a18b"}, + {file = "pydantic_core-2.20.1-cp311-none-win32.whl", hash = "sha256:fa2fddcb7107e0d1808086ca306dcade7df60a13a6c347a7acf1ec139aa6789a"}, + {file = "pydantic_core-2.20.1-cp311-none-win_amd64.whl", hash = "sha256:40a783fb7ee353c50bd3853e626f15677ea527ae556429453685ae32280c19c2"}, + {file = "pydantic_core-2.20.1-cp312-cp312-macosx_10_12_x86_64.whl", hash = "sha256:595ba5be69b35777474fa07f80fc260ea71255656191adb22a8c53aba4479231"}, + {file = "pydantic_core-2.20.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:a4f55095ad087474999ee28d3398bae183a66be4823f753cd7d67dd0153427c9"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f9aa05d09ecf4c75157197f27cdc9cfaeb7c5f15021c6373932bf3e124af029f"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:e97fdf088d4b31ff4ba35db26d9cc472ac7ef4a2ff2badeabf8d727b3377fc52"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:bc633a9fe1eb87e250b5c57d389cf28998e4292336926b0b6cdaee353f89a237"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d573faf8eb7e6b1cbbcb4f5b247c60ca8be39fe2c674495df0eb4318303137fe"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26dc97754b57d2fd00ac2b24dfa341abffc380b823211994c4efac7f13b9e90e"}, + {file = "pydantic_core-2.20.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:33499e85e739a4b60c9dac710c20a08dc73cb3240c9a0e22325e671b27b70d24"}, + {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:bebb4d6715c814597f85297c332297c6ce81e29436125ca59d1159b07f423eb1"}, + {file = "pydantic_core-2.20.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:516d9227919612425c8ef1c9b869bbbee249bc91912c8aaffb66116c0b447ebd"}, + {file = "pydantic_core-2.20.1-cp312-none-win32.whl", hash = "sha256:469f29f9093c9d834432034d33f5fe45699e664f12a13bf38c04967ce233d688"}, + {file = "pydantic_core-2.20.1-cp312-none-win_amd64.whl", hash = "sha256:035ede2e16da7281041f0e626459bcae33ed998cca6a0a007a5ebb73414ac72d"}, + {file = "pydantic_core-2.20.1-cp313-cp313-macosx_10_12_x86_64.whl", hash = "sha256:0827505a5c87e8aa285dc31e9ec7f4a17c81a813d45f70b1d9164e03a813a686"}, + {file = "pydantic_core-2.20.1-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:19c0fa39fa154e7e0b7f82f88ef85faa2a4c23cc65aae2f5aea625e3c13c735a"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4aa223cd1e36b642092c326d694d8bf59b71ddddc94cdb752bbbb1c5c91d833b"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:c336a6d235522a62fef872c6295a42ecb0c4e1d0f1a3e500fe949415761b8a19"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7eb6a0587eded33aeefea9f916899d42b1799b7b14b8f8ff2753c0ac1741edac"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:70c8daf4faca8da5a6d655f9af86faf6ec2e1768f4b8b9d0226c02f3d6209703"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e9fa4c9bf273ca41f940bceb86922a7667cd5bf90e95dbb157cbb8441008482c"}, + {file = "pydantic_core-2.20.1-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:11b71d67b4725e7e2a9f6e9c0ac1239bbc0c48cce3dc59f98635efc57d6dac83"}, + {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:270755f15174fb983890c49881e93f8f1b80f0b5e3a3cc1394a255706cabd203"}, + {file = "pydantic_core-2.20.1-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:c81131869240e3e568916ef4c307f8b99583efaa60a8112ef27a366eefba8ef0"}, + {file = "pydantic_core-2.20.1-cp313-none-win32.whl", hash = "sha256:b91ced227c41aa29c672814f50dbb05ec93536abf8f43cd14ec9521ea09afe4e"}, + {file = "pydantic_core-2.20.1-cp313-none-win_amd64.whl", hash = "sha256:65db0f2eefcaad1a3950f498aabb4875c8890438bc80b19362cf633b87a8ab20"}, + {file = "pydantic_core-2.20.1-cp38-cp38-macosx_10_12_x86_64.whl", hash = "sha256:4745f4ac52cc6686390c40eaa01d48b18997cb130833154801a442323cc78f91"}, + {file = "pydantic_core-2.20.1-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a8ad4c766d3f33ba8fd692f9aa297c9058970530a32c728a2c4bfd2616d3358b"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:41e81317dd6a0127cabce83c0c9c3fbecceae981c8391e6f1dec88a77c8a569a"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:04024d270cf63f586ad41fff13fde4311c4fc13ea74676962c876d9577bcc78f"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eaad4ff2de1c3823fddf82f41121bdf453d922e9a238642b1dedb33c4e4f98ad"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:26ab812fa0c845df815e506be30337e2df27e88399b985d0bb4e3ecfe72df31c"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3c5ebac750d9d5f2706654c638c041635c385596caf68f81342011ddfa1e5598"}, + {file = "pydantic_core-2.20.1-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2aafc5a503855ea5885559eae883978c9b6d8c8993d67766ee73d82e841300dd"}, + {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:4868f6bd7c9d98904b748a2653031fc9c2f85b6237009d475b1008bfaeb0a5aa"}, + {file = "pydantic_core-2.20.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:aa2f457b4af386254372dfa78a2eda2563680d982422641a85f271c859df1987"}, + {file = "pydantic_core-2.20.1-cp38-none-win32.whl", hash = "sha256:225b67a1f6d602de0ce7f6c1c3ae89a4aa25d3de9be857999e9124f15dab486a"}, + {file = "pydantic_core-2.20.1-cp38-none-win_amd64.whl", hash = "sha256:6b507132dcfc0dea440cce23ee2182c0ce7aba7054576efc65634f080dbe9434"}, + {file = "pydantic_core-2.20.1-cp39-cp39-macosx_10_12_x86_64.whl", hash = "sha256:b03f7941783b4c4a26051846dea594628b38f6940a2fdc0df00b221aed39314c"}, + {file = "pydantic_core-2.20.1-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:1eedfeb6089ed3fad42e81a67755846ad4dcc14d73698c120a82e4ccf0f1f9f6"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:635fee4e041ab9c479e31edda27fcf966ea9614fff1317e280d99eb3e5ab6fe2"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:77bf3ac639c1ff567ae3b47f8d4cc3dc20f9966a2a6dd2311dcc055d3d04fb8a"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:7ed1b0132f24beeec5a78b67d9388656d03e6a7c837394f99257e2d55b461611"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c6514f963b023aeee506678a1cf821fe31159b925c4b76fe2afa94cc70b3222b"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:10d4204d8ca33146e761c79f83cc861df20e7ae9f6487ca290a97702daf56006"}, + {file = "pydantic_core-2.20.1-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:2d036c7187b9422ae5b262badb87a20a49eb6c5238b2004e96d4da1231badef1"}, + {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:9ebfef07dbe1d93efb94b4700f2d278494e9162565a54f124c404a5656d7ff09"}, + {file = "pydantic_core-2.20.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:6b9d9bb600328a1ce523ab4f454859e9d439150abb0906c5a1983c146580ebab"}, + {file = "pydantic_core-2.20.1-cp39-none-win32.whl", hash = "sha256:784c1214cb6dd1e3b15dd8b91b9a53852aed16671cc3fbe4786f4f1db07089e2"}, + {file = "pydantic_core-2.20.1-cp39-none-win_amd64.whl", hash = "sha256:d2fe69c5434391727efa54b47a1e7986bb0186e72a41b203df8f5b0a19a4f669"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_10_12_x86_64.whl", hash = "sha256:a45f84b09ac9c3d35dfcf6a27fd0634d30d183205230a0ebe8373a0e8cfa0906"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-macosx_11_0_arm64.whl", hash = "sha256:d02a72df14dfdbaf228424573a07af10637bd490f0901cee872c4f434a735b94"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d2b27e6af28f07e2f195552b37d7d66b150adbaa39a6d327766ffd695799780f"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:084659fac3c83fd674596612aeff6041a18402f1e1bc19ca39e417d554468482"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:242b8feb3c493ab78be289c034a1f659e8826e2233786e36f2893a950a719bb6"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:38cf1c40a921d05c5edc61a785c0ddb4bed67827069f535d794ce6bcded919fc"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:e0bbdd76ce9aa5d4209d65f2b27fc6e5ef1312ae6c5333c26db3f5ade53a1e99"}, + {file = "pydantic_core-2.20.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:254ec27fdb5b1ee60684f91683be95e5133c994cc54e86a0b0963afa25c8f8a6"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_10_12_x86_64.whl", hash = "sha256:407653af5617f0757261ae249d3fba09504d7a71ab36ac057c938572d1bc9331"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-macosx_11_0_arm64.whl", hash = "sha256:c693e916709c2465b02ca0ad7b387c4f8423d1db7b4649c551f27a529181c5ad"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5b5ff4911aea936a47d9376fd3ab17e970cc543d1b68921886e7f64bd28308d1"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:177f55a886d74f1808763976ac4efd29b7ed15c69f4d838bbd74d9d09cf6fa86"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-manylinux_2_5_i686.manylinux1_i686.whl", hash = "sha256:964faa8a861d2664f0c7ab0c181af0bea66098b1919439815ca8803ef136fc4e"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_aarch64.whl", hash = "sha256:4dd484681c15e6b9a977c785a345d3e378d72678fd5f1f3c0509608da24f2ac0"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-musllinux_1_1_x86_64.whl", hash = "sha256:f6d6cff3538391e8486a431569b77921adfcdef14eb18fbf19b7c0a5294d4e6a"}, + {file = "pydantic_core-2.20.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:a6d511cc297ff0883bc3708b465ff82d7560193169a8b93260f74ecb0a5e08a7"}, + {file = "pydantic_core-2.20.1.tar.gz", hash = "sha256:26ca695eeee5f9f1aeeb211ffc12f10bcb6f71e2989988fda61dabd65db878d4"}, +] + +[package.dependencies] +typing-extensions = ">=4.6.0,<4.7.0 || >4.7.0" [[package]] name = "pytest" @@ -771,7 +1074,6 @@ files = [ atomicwrites = {version = ">=1.0", markers = "sys_platform == \"win32\""} attrs = ">=19.2.0" colorama = {version = "*", markers = "sys_platform == \"win32\""} -importlib-metadata = {version = ">=0.12", markers = "python_version < \"3.8\""} iniconfig = "*" packaging = "*" pluggy = ">=0.12,<2.0" @@ -794,7 +1096,6 @@ files = [ [package.dependencies] pytest = ">=6.1.0" -typing-extensions = {version = ">=4.0", markers = "python_version < \"3.8\""} [package.extras] testing = ["coverage (==6.2)", "flaky (>=3.5.0)", "hypothesis (>=5.7.1)", "mypy (==0.931)"] @@ -832,13 +1133,12 @@ files = [ attrs = ">=19.0" filelock = ">=3.0" mypy = [ - {version = ">=0.500", markers = "python_version < \"3.8\""}, - {version = ">=0.700", markers = "python_version >= \"3.8\" and python_version < \"3.9\""}, {version = ">=0.780", markers = "python_version >= \"3.9\""}, + {version = ">=0.700", markers = "python_version >= \"3.8\" and python_version < \"3.9\""}, ] pytest = [ - {version = ">=4.6", markers = "python_version >= \"3.6\" and python_version < \"3.10\""}, {version = ">=6.2", markers = "python_version >= \"3.10\""}, + {version = ">=4.6", markers = "python_version >= \"3.6\" and python_version < \"3.10\""}, ] [[package]] @@ -863,159 +1163,163 @@ regex = "*" [[package]] name = "pyyaml" -version = "6.0" +version = "6.0.2" description = "YAML parser and emitter for Python" optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "PyYAML-6.0-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:d4db7c7aef085872ef65a8fd7d6d09a14ae91f691dec3e87ee5ee0539d516f53"}, - {file = "PyYAML-6.0-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:9df7ed3b3d2e0ecfe09e14741b857df43adb5a3ddadc919a2d94fbdf78fea53c"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:77f396e6ef4c73fdc33a9157446466f1cff553d979bd00ecb64385760c6babdc"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a80a78046a72361de73f8f395f1f1e49f956c6be882eed58505a15f3e430962b"}, - {file = "PyYAML-6.0-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:f84fbc98b019fef2ee9a1cb3ce93e3187a6df0b2538a651bfb890254ba9f90b5"}, - {file = "PyYAML-6.0-cp310-cp310-win32.whl", hash = "sha256:2cd5df3de48857ed0544b34e2d40e9fac445930039f3cfe4bcc592a1f836d513"}, - {file = "PyYAML-6.0-cp310-cp310-win_amd64.whl", hash = "sha256:daf496c58a8c52083df09b80c860005194014c3698698d1a57cbcfa182142a3a"}, - {file = "PyYAML-6.0-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:d4b0ba9512519522b118090257be113b9468d804b19d63c71dbcf4a48fa32358"}, - {file = "PyYAML-6.0-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:81957921f441d50af23654aa6c5e5eaf9b06aba7f0a19c18a538dc7ef291c5a1"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:afa17f5bc4d1b10afd4466fd3a44dc0e245382deca5b3c353d8b757f9e3ecb8d"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:dbad0e9d368bb989f4515da330b88a057617d16b6a8245084f1b05400f24609f"}, - {file = "PyYAML-6.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:432557aa2c09802be39460360ddffd48156e30721f5e8d917f01d31694216782"}, - {file = "PyYAML-6.0-cp311-cp311-win32.whl", hash = "sha256:bfaef573a63ba8923503d27530362590ff4f576c626d86a9fed95822a8255fd7"}, - {file = "PyYAML-6.0-cp311-cp311-win_amd64.whl", hash = "sha256:01b45c0191e6d66c470b6cf1b9531a771a83c1c4208272ead47a3ae4f2f603bf"}, - {file = "PyYAML-6.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:897b80890765f037df3403d22bab41627ca8811ae55e9a722fd0392850ec4d86"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:50602afada6d6cbfad699b0c7bb50d5ccffa7e46a3d738092afddc1f9758427f"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:48c346915c114f5fdb3ead70312bd042a953a8ce5c7106d5bfb1a5254e47da92"}, - {file = "PyYAML-6.0-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:98c4d36e99714e55cfbaaee6dd5badbc9a1ec339ebfc3b1f52e293aee6bb71a4"}, - {file = "PyYAML-6.0-cp36-cp36m-win32.whl", hash = "sha256:0283c35a6a9fbf047493e3a0ce8d79ef5030852c51e9d911a27badfde0605293"}, - {file = "PyYAML-6.0-cp36-cp36m-win_amd64.whl", hash = "sha256:07751360502caac1c067a8132d150cf3d61339af5691fe9e87803040dbc5db57"}, - {file = "PyYAML-6.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:819b3830a1543db06c4d4b865e70ded25be52a2e0631ccd2f6a47a2822f2fd7c"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:473f9edb243cb1935ab5a084eb238d842fb8f404ed2193a915d1784b5a6b5fc0"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:0ce82d761c532fe4ec3f87fc45688bdd3a4c1dc5e0b4a19814b9009a29baefd4"}, - {file = "PyYAML-6.0-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:231710d57adfd809ef5d34183b8ed1eeae3f76459c18fb4a0b373ad56bedcdd9"}, - {file = "PyYAML-6.0-cp37-cp37m-win32.whl", hash = "sha256:c5687b8d43cf58545ade1fe3e055f70eac7a5a1a0bf42824308d868289a95737"}, - {file = "PyYAML-6.0-cp37-cp37m-win_amd64.whl", hash = "sha256:d15a181d1ecd0d4270dc32edb46f7cb7733c7c508857278d3d378d14d606db2d"}, - {file = "PyYAML-6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:0b4624f379dab24d3725ffde76559cff63d9ec94e1736b556dacdfebe5ab6d4b"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:213c60cd50106436cc818accf5baa1aba61c0189ff610f64f4a3e8c6726218ba"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9fa600030013c4de8165339db93d182b9431076eb98eb40ee068700c9c813e34"}, - {file = "PyYAML-6.0-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:277a0ef2981ca40581a47093e9e2d13b3f1fbbeffae064c1d21bfceba2030287"}, - {file = "PyYAML-6.0-cp38-cp38-win32.whl", hash = "sha256:d4eccecf9adf6fbcc6861a38015c2a64f38b9d94838ac1810a9023a0609e1b78"}, - {file = "PyYAML-6.0-cp38-cp38-win_amd64.whl", hash = "sha256:1e4747bc279b4f613a09eb64bba2ba602d8a6664c6ce6396a4d0cd413a50ce07"}, - {file = "PyYAML-6.0-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:055d937d65826939cb044fc8c9b08889e8c743fdc6a32b33e2390f66013e449b"}, - {file = "PyYAML-6.0-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:e61ceaab6f49fb8bdfaa0f92c4b57bcfbea54c09277b1b4f7ac376bfb7a7c174"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d67d839ede4ed1b28a4e8909735fc992a923cdb84e618544973d7dfc71540803"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:cba8c411ef271aa037d7357a2bc8f9ee8b58b9965831d9e51baf703280dc73d3"}, - {file = "PyYAML-6.0-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:40527857252b61eacd1d9af500c3337ba8deb8fc298940291486c465c8b46ec0"}, - {file = "PyYAML-6.0-cp39-cp39-win32.whl", hash = "sha256:b5b9eccad747aabaaffbc6064800670f0c297e52c12754eb1d976c57e4f74dcb"}, - {file = "PyYAML-6.0-cp39-cp39-win_amd64.whl", hash = "sha256:b3d267842bf12586ba6c734f89d1f5b871df0273157918b0ccefa29deb05c21c"}, - {file = "PyYAML-6.0.tar.gz", hash = "sha256:68fb519c14306fec9720a2a5b45bc9f0c8d1b9c72adf45c37baedfcd949c35a2"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:0a9a2848a5b7feac301353437eb7d5957887edbf81d56e903999a75a3d743086"}, + {file = "PyYAML-6.0.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:29717114e51c84ddfba879543fb232a6ed60086602313ca38cce623c1d62cfbf"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:8824b5a04a04a047e72eea5cec3bc266db09e35de6bdfe34c9436ac5ee27d237"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7c36280e6fb8385e520936c3cb3b8042851904eba0e58d277dca80a5cfed590b"}, + {file = "PyYAML-6.0.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ec031d5d2feb36d1d1a24380e4db6d43695f3748343d99434e6f5f9156aaa2ed"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:936d68689298c36b53b29f23c6dbb74de12b4ac12ca6cfe0e047bedceea56180"}, + {file = "PyYAML-6.0.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:23502f431948090f597378482b4812b0caae32c22213aecf3b55325e049a6c68"}, + {file = "PyYAML-6.0.2-cp310-cp310-win32.whl", hash = "sha256:2e99c6826ffa974fe6e27cdb5ed0021786b03fc98e5ee3c5bfe1fd5015f42b99"}, + {file = "PyYAML-6.0.2-cp310-cp310-win_amd64.whl", hash = "sha256:a4d3091415f010369ae4ed1fc6b79def9416358877534caf6a0fdd2146c87a3e"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:cc1c1159b3d456576af7a3e4d1ba7e6924cb39de8f67111c735f6fc832082774"}, + {file = "PyYAML-6.0.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:1e2120ef853f59c7419231f3bf4e7021f1b936f6ebd222406c3b60212205d2ee"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5d225db5a45f21e78dd9358e58a98702a0302f2659a3c6cd320564b75b86f47c"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5ac9328ec4831237bec75defaf839f7d4564be1e6b25ac710bd1a96321cc8317"}, + {file = "PyYAML-6.0.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3ad2a3decf9aaba3d29c8f537ac4b243e36bef957511b4766cb0057d32b0be85"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ff3824dc5261f50c9b0dfb3be22b4567a6f938ccce4587b38952d85fd9e9afe4"}, + {file = "PyYAML-6.0.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:797b4f722ffa07cc8d62053e4cff1486fa6dc094105d13fea7b1de7d8bf71c9e"}, + {file = "PyYAML-6.0.2-cp311-cp311-win32.whl", hash = "sha256:11d8f3dd2b9c1207dcaf2ee0bbbfd5991f571186ec9cc78427ba5bd32afae4b5"}, + {file = "PyYAML-6.0.2-cp311-cp311-win_amd64.whl", hash = "sha256:e10ce637b18caea04431ce14fabcf5c64a1c61ec9c56b071a4b7ca131ca52d44"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:c70c95198c015b85feafc136515252a261a84561b7b1d51e3384e0655ddf25ab"}, + {file = "PyYAML-6.0.2-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:ce826d6ef20b1bc864f0a68340c8b3287705cae2f8b4b1d932177dcc76721725"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1f71ea527786de97d1a0cc0eacd1defc0985dcf6b3f17bb77dcfc8c34bec4dc5"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9b22676e8097e9e22e36d6b7bda33190d0d400f345f23d4065d48f4ca7ae0425"}, + {file = "PyYAML-6.0.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:80bab7bfc629882493af4aa31a4cfa43a4c57c83813253626916b8c7ada83476"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:0833f8694549e586547b576dcfaba4a6b55b9e96098b36cdc7ebefe667dfed48"}, + {file = "PyYAML-6.0.2-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8b9c7197f7cb2738065c481a0461e50ad02f18c78cd75775628afb4d7137fb3b"}, + {file = "PyYAML-6.0.2-cp312-cp312-win32.whl", hash = "sha256:ef6107725bd54b262d6dedcc2af448a266975032bc85ef0172c5f059da6325b4"}, + {file = "PyYAML-6.0.2-cp312-cp312-win_amd64.whl", hash = "sha256:7e7401d0de89a9a855c839bc697c079a4af81cf878373abd7dc625847d25cbd8"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:efdca5630322a10774e8e98e1af481aad470dd62c3170801852d752aa7a783ba"}, + {file = "PyYAML-6.0.2-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:50187695423ffe49e2deacb8cd10510bc361faac997de9efef88badc3bb9e2d1"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0ffe8360bab4910ef1b9e87fb812d8bc0a308b0d0eef8c8f44e0254ab3b07133"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:17e311b6c678207928d649faa7cb0d7b4c26a0ba73d41e99c4fff6b6c3276484"}, + {file = "PyYAML-6.0.2-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:70b189594dbe54f75ab3a1acec5f1e3faa7e8cf2f1e08d9b561cb41b845f69d5"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_aarch64.whl", hash = "sha256:41e4e3953a79407c794916fa277a82531dd93aad34e29c2a514c2c0c5fe971cc"}, + {file = "PyYAML-6.0.2-cp313-cp313-musllinux_1_1_x86_64.whl", hash = "sha256:68ccc6023a3400877818152ad9a1033e3db8625d899c72eacb5a668902e4d652"}, + {file = "PyYAML-6.0.2-cp313-cp313-win32.whl", hash = "sha256:bc2fa7c6b47d6bc618dd7fb02ef6fdedb1090ec036abab80d4681424b84c1183"}, + {file = "PyYAML-6.0.2-cp313-cp313-win_amd64.whl", hash = "sha256:8388ee1976c416731879ac16da0aff3f63b286ffdd57cdeb95f3f2e085687563"}, + {file = "PyYAML-6.0.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:24471b829b3bf607e04e88d79542a9d48bb037c2267d7927a874e6c205ca7e9a"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7fded462629cfa4b685c5416b949ebad6cec74af5e2d42905d41e257e0869f5"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d84a1718ee396f54f3a086ea0a66d8e552b2ab2017ef8b420e92edbc841c352d"}, + {file = "PyYAML-6.0.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9056c1ecd25795207ad294bcf39f2db3d845767be0ea6e6a34d856f006006083"}, + {file = "PyYAML-6.0.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:82d09873e40955485746739bcb8b4586983670466c23382c19cffecbf1fd8706"}, + {file = "PyYAML-6.0.2-cp38-cp38-win32.whl", hash = "sha256:43fa96a3ca0d6b1812e01ced1044a003533c47f6ee8aca31724f78e93ccc089a"}, + {file = "PyYAML-6.0.2-cp38-cp38-win_amd64.whl", hash = "sha256:01179a4a8559ab5de078078f37e5c1a30d76bb88519906844fd7bdea1b7729ff"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:688ba32a1cffef67fd2e9398a2efebaea461578b0923624778664cc1c914db5d"}, + {file = "PyYAML-6.0.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a8786accb172bd8afb8be14490a16625cbc387036876ab6ba70912730faf8e1f"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d8e03406cac8513435335dbab54c0d385e4a49e4945d2909a581c83647ca0290"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:f753120cb8181e736c57ef7636e83f31b9c0d1722c516f7e86cf15b7aa57ff12"}, + {file = "PyYAML-6.0.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:3b1fdb9dc17f5a7677423d508ab4f243a726dea51fa5e70992e59a7411c89d19"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:0b69e4ce7a131fe56b7e4d770c67429700908fc0752af059838b1cfb41960e4e"}, + {file = "PyYAML-6.0.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:a9f8c2e67970f13b16084e04f134610fd1d374bf477b17ec1599185cf611d725"}, + {file = "PyYAML-6.0.2-cp39-cp39-win32.whl", hash = "sha256:6395c297d42274772abc367baaa79683958044e5d3835486c16da75d2a694631"}, + {file = "PyYAML-6.0.2-cp39-cp39-win_amd64.whl", hash = "sha256:39693e1f8320ae4f43943590b49779ffb98acb81f788220ea932a6b6c51004d8"}, + {file = "pyyaml-6.0.2.tar.gz", hash = "sha256:d584d9ec91ad65861cc08d42e834324ef890a082e591037abe114850ff7bbc3e"}, ] [[package]] name = "regex" -version = "2023.6.3" +version = "2024.7.24" description = "Alternative regular expression module, to replace re." optional = false -python-versions = ">=3.6" +python-versions = ">=3.8" files = [ - {file = "regex-2023.6.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:824bf3ac11001849aec3fa1d69abcb67aac3e150a933963fb12bda5151fe1bfd"}, - {file = "regex-2023.6.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:05ed27acdf4465c95826962528f9e8d41dbf9b1aa8531a387dee6ed215a3e9ef"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0b49c764f88a79160fa64f9a7b425620e87c9f46095ef9c9920542ab2495c8bc"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8e3f1316c2293e5469f8f09dc2d76efb6c3982d3da91ba95061a7e69489a14ef"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:43e1dd9d12df9004246bacb79a0e5886b3b6071b32e41f83b0acbf293f820ee8"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4959e8bcbfda5146477d21c3a8ad81b185cd252f3d0d6e4724a5ef11c012fb06"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:af4dd387354dc83a3bff67127a124c21116feb0d2ef536805c454721c5d7993d"}, - {file = "regex-2023.6.3-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2239d95d8e243658b8dbb36b12bd10c33ad6e6933a54d36ff053713f129aa536"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:890e5a11c97cf0d0c550eb661b937a1e45431ffa79803b942a057c4fb12a2da2"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a8105e9af3b029f243ab11ad47c19b566482c150c754e4c717900a798806b222"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:25be746a8ec7bc7b082783216de8e9473803706723b3f6bef34b3d0ed03d57e2"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:3676f1dd082be28b1266c93f618ee07741b704ab7b68501a173ce7d8d0d0ca18"}, - {file = "regex-2023.6.3-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:10cb847aeb1728412c666ab2e2000ba6f174f25b2bdc7292e7dd71b16db07568"}, - {file = "regex-2023.6.3-cp310-cp310-win32.whl", hash = "sha256:dbbbfce33cd98f97f6bffb17801b0576e653f4fdb1d399b2ea89638bc8d08ae1"}, - {file = "regex-2023.6.3-cp310-cp310-win_amd64.whl", hash = "sha256:c5f8037000eb21e4823aa485149f2299eb589f8d1fe4b448036d230c3f4e68e0"}, - {file = "regex-2023.6.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c123f662be8ec5ab4ea72ea300359023a5d1df095b7ead76fedcd8babbedf969"}, - {file = "regex-2023.6.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:9edcbad1f8a407e450fbac88d89e04e0b99a08473f666a3f3de0fd292badb6aa"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dcba6dae7de533c876255317c11f3abe4907ba7d9aa15d13e3d9710d4315ec0e"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:29cdd471ebf9e0f2fb3cac165efedc3c58db841d83a518b082077e612d3ee5df"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:12b74fbbf6cbbf9dbce20eb9b5879469e97aeeaa874145517563cca4029db65c"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:0c29ca1bd61b16b67be247be87390ef1d1ef702800f91fbd1991f5c4421ebae8"}, - {file = "regex-2023.6.3-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d77f09bc4b55d4bf7cc5eba785d87001d6757b7c9eec237fe2af57aba1a071d9"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:ea353ecb6ab5f7e7d2f4372b1e779796ebd7b37352d290096978fea83c4dba0c"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:10590510780b7541969287512d1b43f19f965c2ece6c9b1c00fc367b29d8dce7"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:e2fbd6236aae3b7f9d514312cdb58e6494ee1c76a9948adde6eba33eb1c4264f"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:6b2675068c8b56f6bfd5a2bda55b8accbb96c02fd563704732fd1c95e2083461"}, - {file = "regex-2023.6.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:74419d2b50ecb98360cfaa2974da8689cb3b45b9deff0dcf489c0d333bcc1477"}, - {file = "regex-2023.6.3-cp311-cp311-win32.whl", hash = "sha256:fb5ec16523dc573a4b277663a2b5a364e2099902d3944c9419a40ebd56a118f9"}, - {file = "regex-2023.6.3-cp311-cp311-win_amd64.whl", hash = "sha256:09e4a1a6acc39294a36b7338819b10baceb227f7f7dbbea0506d419b5a1dd8af"}, - {file = "regex-2023.6.3-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:0654bca0cdf28a5956c83839162692725159f4cda8d63e0911a2c0dc76166525"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:463b6a3ceb5ca952e66550a4532cef94c9a0c80dc156c4cc343041951aec1697"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:87b2a5bb5e78ee0ad1de71c664d6eb536dc3947a46a69182a90f4410f5e3f7dd"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:6343c6928282c1f6a9db41f5fd551662310e8774c0e5ebccb767002fcf663ca9"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:b6192d5af2ccd2a38877bfef086d35e6659566a335b1492786ff254c168b1693"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:74390d18c75054947e4194019077e243c06fbb62e541d8817a0fa822ea310c14"}, - {file = "regex-2023.6.3-cp36-cp36m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:742e19a90d9bb2f4a6cf2862b8b06dea5e09b96c9f2df1779e53432d7275331f"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:8abbc5d54ea0ee80e37fef009e3cec5dafd722ed3c829126253d3e22f3846f1e"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:c2b867c17a7a7ae44c43ebbeb1b5ff406b3e8d5b3e14662683e5e66e6cc868d3"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:d831c2f8ff278179705ca59f7e8524069c1a989e716a1874d6d1aab6119d91d1"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_s390x.whl", hash = "sha256:ee2d1a9a253b1729bb2de27d41f696ae893507c7db224436abe83ee25356f5c1"}, - {file = "regex-2023.6.3-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:61474f0b41fe1a80e8dfa70f70ea1e047387b7cd01c85ec88fa44f5d7561d787"}, - {file = "regex-2023.6.3-cp36-cp36m-win32.whl", hash = "sha256:0b71e63226e393b534105fcbdd8740410dc6b0854c2bfa39bbda6b0d40e59a54"}, - {file = "regex-2023.6.3-cp36-cp36m-win_amd64.whl", hash = "sha256:bbb02fd4462f37060122e5acacec78e49c0fbb303c30dd49c7f493cf21fc5b27"}, - {file = "regex-2023.6.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:b862c2b9d5ae38a68b92e215b93f98d4c5e9454fa36aae4450f61dd33ff48487"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:976d7a304b59ede34ca2921305b57356694f9e6879db323fd90a80f865d355a3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:83320a09188e0e6c39088355d423aa9d056ad57a0b6c6381b300ec1a04ec3d16"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:9427a399501818a7564f8c90eced1e9e20709ece36be701f394ada99890ea4b3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7178bbc1b2ec40eaca599d13c092079bf529679bf0371c602edaa555e10b41c3"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:837328d14cde912af625d5f303ec29f7e28cdab588674897baafaf505341f2fc"}, - {file = "regex-2023.6.3-cp37-cp37m-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:2d44dc13229905ae96dd2ae2dd7cebf824ee92bc52e8cf03dcead37d926da019"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d54af539295392611e7efbe94e827311eb8b29668e2b3f4cadcfe6f46df9c777"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:7117d10690c38a622e54c432dfbbd3cbd92f09401d622902c32f6d377e2300ee"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:bb60b503ec8a6e4e3e03a681072fa3a5adcbfa5479fa2d898ae2b4a8e24c4591"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:65ba8603753cec91c71de423a943ba506363b0e5c3fdb913ef8f9caa14b2c7e0"}, - {file = "regex-2023.6.3-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:271f0bdba3c70b58e6f500b205d10a36fb4b58bd06ac61381b68de66442efddb"}, - {file = "regex-2023.6.3-cp37-cp37m-win32.whl", hash = "sha256:9beb322958aaca059f34975b0df135181f2e5d7a13b84d3e0e45434749cb20f7"}, - {file = "regex-2023.6.3-cp37-cp37m-win_amd64.whl", hash = "sha256:fea75c3710d4f31389eed3c02f62d0b66a9da282521075061ce875eb5300cf23"}, - {file = "regex-2023.6.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8f56fcb7ff7bf7404becdfc60b1e81a6d0561807051fd2f1860b0d0348156a07"}, - {file = "regex-2023.6.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:d2da3abc88711bce7557412310dfa50327d5769a31d1c894b58eb256459dc289"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a99b50300df5add73d307cf66abea093304a07eb017bce94f01e795090dea87c"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5708089ed5b40a7b2dc561e0c8baa9535b77771b64a8330b684823cfd5116036"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:687ea9d78a4b1cf82f8479cab23678aff723108df3edeac098e5b2498879f4a7"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:4d3850beab9f527f06ccc94b446c864059c57651b3f911fddb8d9d3ec1d1b25d"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e8915cc96abeb8983cea1df3c939e3c6e1ac778340c17732eb63bb96247b91d2"}, - {file = "regex-2023.6.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:841d6e0e5663d4c7b4c8099c9997be748677d46cbf43f9f471150e560791f7ff"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:9edce5281f965cf135e19840f4d93d55b3835122aa76ccacfd389e880ba4cf82"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:b956231ebdc45f5b7a2e1f90f66a12be9610ce775fe1b1d50414aac1e9206c06"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:36efeba71c6539d23c4643be88295ce8c82c88bbd7c65e8a24081d2ca123da3f"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:cf67ca618b4fd34aee78740bea954d7c69fdda419eb208c2c0c7060bb822d747"}, - {file = "regex-2023.6.3-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:b4598b1897837067a57b08147a68ac026c1e73b31ef6e36deeeb1fa60b2933c9"}, - {file = "regex-2023.6.3-cp38-cp38-win32.whl", hash = "sha256:f415f802fbcafed5dcc694c13b1292f07fe0befdb94aa8a52905bd115ff41e88"}, - {file = "regex-2023.6.3-cp38-cp38-win_amd64.whl", hash = "sha256:d4f03bb71d482f979bda92e1427f3ec9b220e62a7dd337af0aa6b47bf4498f72"}, - {file = "regex-2023.6.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ccf91346b7bd20c790310c4147eee6ed495a54ddb6737162a36ce9dbef3e4751"}, - {file = "regex-2023.6.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:b28f5024a3a041009eb4c333863d7894d191215b39576535c6734cd88b0fcb68"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e0bb18053dfcfed432cc3ac632b5e5e5c5b7e55fb3f8090e867bfd9b054dbcbf"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9a5bfb3004f2144a084a16ce19ca56b8ac46e6fd0651f54269fc9e230edb5e4a"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:5c6b48d0fa50d8f4df3daf451be7f9689c2bde1a52b1225c5926e3f54b6a9ed1"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:051da80e6eeb6e239e394ae60704d2b566aa6a7aed6f2890a7967307267a5dc6"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a4c3b7fa4cdaa69268748665a1a6ff70c014d39bb69c50fda64b396c9116cf77"}, - {file = "regex-2023.6.3-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:457b6cce21bee41ac292d6753d5e94dcbc5c9e3e3a834da285b0bde7aa4a11e9"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:aad51907d74fc183033ad796dd4c2e080d1adcc4fd3c0fd4fd499f30c03011cd"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:0385e73da22363778ef2324950e08b689abdf0b108a7d8decb403ad7f5191938"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:c6a57b742133830eec44d9b2290daf5cbe0a2f1d6acee1b3c7b1c7b2f3606df7"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:3e5219bf9e75993d73ab3d25985c857c77e614525fac9ae02b1bebd92f7cecac"}, - {file = "regex-2023.6.3-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:e5087a3c59eef624a4591ef9eaa6e9a8d8a94c779dade95d27c0bc24650261cd"}, - {file = "regex-2023.6.3-cp39-cp39-win32.whl", hash = "sha256:20326216cc2afe69b6e98528160b225d72f85ab080cbdf0b11528cbbaba2248f"}, - {file = "regex-2023.6.3-cp39-cp39-win_amd64.whl", hash = "sha256:bdff5eab10e59cf26bc479f565e25ed71a7d041d1ded04ccf9aee1d9f208487a"}, - {file = "regex-2023.6.3.tar.gz", hash = "sha256:72d1a25bf36d2050ceb35b517afe13864865268dfb45910e2e17a84be6cbfeb0"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:228b0d3f567fafa0633aee87f08b9276c7062da9616931382993c03808bb68ce"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:3426de3b91d1bc73249042742f45c2148803c111d1175b283270177fdf669024"}, + {file = "regex-2024.7.24-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:f273674b445bcb6e4409bf8d1be67bc4b58e8b46fd0d560055d515b8830063cd"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:23acc72f0f4e1a9e6e9843d6328177ae3074b4182167e34119ec7233dfeccf53"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:65fd3d2e228cae024c411c5ccdffae4c315271eee4a8b839291f84f796b34eca"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c414cbda77dbf13c3bc88b073a1a9f375c7b0cb5e115e15d4b73ec3a2fbc6f59"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf7a89eef64b5455835f5ed30254ec19bf41f7541cd94f266ab7cbd463f00c41"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:19c65b00d42804e3fbea9708f0937d157e53429a39b7c61253ff15670ff62cb5"}, + {file = "regex-2024.7.24-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:7a5486ca56c8869070a966321d5ab416ff0f83f30e0e2da1ab48815c8d165d46"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:6f51f9556785e5a203713f5efd9c085b4a45aecd2a42573e2b5041881b588d1f"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:a4997716674d36a82eab3e86f8fa77080a5d8d96a389a61ea1d0e3a94a582cf7"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:c0abb5e4e8ce71a61d9446040c1e86d4e6d23f9097275c5bd49ed978755ff0fe"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_s390x.whl", hash = "sha256:18300a1d78cf1290fa583cd8b7cde26ecb73e9f5916690cf9d42de569c89b1ce"}, + {file = "regex-2024.7.24-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:416c0e4f56308f34cdb18c3f59849479dde5b19febdcd6e6fa4d04b6c31c9faa"}, + {file = "regex-2024.7.24-cp310-cp310-win32.whl", hash = "sha256:fb168b5924bef397b5ba13aabd8cf5df7d3d93f10218d7b925e360d436863f66"}, + {file = "regex-2024.7.24-cp310-cp310-win_amd64.whl", hash = "sha256:6b9fc7e9cc983e75e2518496ba1afc524227c163e43d706688a6bb9eca41617e"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:382281306e3adaaa7b8b9ebbb3ffb43358a7bbf585fa93821300a418bb975281"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4fdd1384619f406ad9037fe6b6eaa3de2749e2e12084abc80169e8e075377d3b"}, + {file = "regex-2024.7.24-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:3d974d24edb231446f708c455fd08f94c41c1ff4f04bcf06e5f36df5ef50b95a"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a2ec4419a3fe6cf8a4795752596dfe0adb4aea40d3683a132bae9c30b81e8d73"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:eb563dd3aea54c797adf513eeec819c4213d7dbfc311874eb4fd28d10f2ff0f2"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:45104baae8b9f67569f0f1dca5e1f1ed77a54ae1cd8b0b07aba89272710db61e"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:994448ee01864501912abf2bad9203bffc34158e80fe8bfb5b031f4f8e16da51"}, + {file = "regex-2024.7.24-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3fac296f99283ac232d8125be932c5cd7644084a30748fda013028c815ba3364"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:7e37e809b9303ec3a179085415cb5f418ecf65ec98cdfe34f6a078b46ef823ee"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:01b689e887f612610c869421241e075c02f2e3d1ae93a037cb14f88ab6a8934c"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:f6442f0f0ff81775eaa5b05af8a0ffa1dda36e9cf6ec1e0d3d245e8564b684ce"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_s390x.whl", hash = "sha256:871e3ab2838fbcb4e0865a6e01233975df3a15e6fce93b6f99d75cacbd9862d1"}, + {file = "regex-2024.7.24-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:c918b7a1e26b4ab40409820ddccc5d49871a82329640f5005f73572d5eaa9b5e"}, + {file = "regex-2024.7.24-cp311-cp311-win32.whl", hash = "sha256:2dfbb8baf8ba2c2b9aa2807f44ed272f0913eeeba002478c4577b8d29cde215c"}, + {file = "regex-2024.7.24-cp311-cp311-win_amd64.whl", hash = "sha256:538d30cd96ed7d1416d3956f94d54e426a8daf7c14527f6e0d6d425fcb4cca52"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:fe4ebef608553aff8deb845c7f4f1d0740ff76fa672c011cc0bacb2a00fbde86"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:74007a5b25b7a678459f06559504f1eec2f0f17bca218c9d56f6a0a12bfffdad"}, + {file = "regex-2024.7.24-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:7df9ea48641da022c2a3c9c641650cd09f0cd15e8908bf931ad538f5ca7919c9"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:6a1141a1dcc32904c47f6846b040275c6e5de0bf73f17d7a409035d55b76f289"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:80c811cfcb5c331237d9bad3bea2c391114588cf4131707e84d9493064d267f9"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:7214477bf9bd195894cf24005b1e7b496f46833337b5dedb7b2a6e33f66d962c"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d55588cba7553f0b6ec33130bc3e114b355570b45785cebdc9daed8c637dd440"}, + {file = "regex-2024.7.24-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:558a57cfc32adcf19d3f791f62b5ff564922942e389e3cfdb538a23d65a6b610"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:a512eed9dfd4117110b1881ba9a59b31433caed0c4101b361f768e7bcbaf93c5"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:86b17ba823ea76256b1885652e3a141a99a5c4422f4a869189db328321b73799"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:5eefee9bfe23f6df09ffb6dfb23809f4d74a78acef004aa904dc7c88b9944b05"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_s390x.whl", hash = "sha256:731fcd76bbdbf225e2eb85b7c38da9633ad3073822f5ab32379381e8c3c12e94"}, + {file = "regex-2024.7.24-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:eaef80eac3b4cfbdd6de53c6e108b4c534c21ae055d1dbea2de6b3b8ff3def38"}, + {file = "regex-2024.7.24-cp312-cp312-win32.whl", hash = "sha256:185e029368d6f89f36e526764cf12bf8d6f0e3a2a7737da625a76f594bdfcbfc"}, + {file = "regex-2024.7.24-cp312-cp312-win_amd64.whl", hash = "sha256:2f1baff13cc2521bea83ab2528e7a80cbe0ebb2c6f0bfad15be7da3aed443908"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:66b4c0731a5c81921e938dcf1a88e978264e26e6ac4ec96a4d21ae0354581ae0"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:88ecc3afd7e776967fa16c80f974cb79399ee8dc6c96423321d6f7d4b881c92b"}, + {file = "regex-2024.7.24-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:64bd50cf16bcc54b274e20235bf8edbb64184a30e1e53873ff8d444e7ac656b2"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:eb462f0e346fcf41a901a126b50f8781e9a474d3927930f3490f38a6e73b6950"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:a82465ebbc9b1c5c50738536fdfa7cab639a261a99b469c9d4c7dcbb2b3f1e57"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:68a8f8c046c6466ac61a36b65bb2395c74451df2ffb8458492ef49900efed293"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:dac8e84fff5d27420f3c1e879ce9929108e873667ec87e0c8eeb413a5311adfe"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba2537ef2163db9e6ccdbeb6f6424282ae4dea43177402152c67ef869cf3978b"}, + {file = "regex-2024.7.24-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:43affe33137fcd679bdae93fb25924979517e011f9dea99163f80b82eadc7e53"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:c9bb87fdf2ab2370f21e4d5636e5317775e5d51ff32ebff2cf389f71b9b13750"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:945352286a541406f99b2655c973852da7911b3f4264e010218bbc1cc73168f2"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:8bc593dcce679206b60a538c302d03c29b18e3d862609317cb560e18b66d10cf"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_s390x.whl", hash = "sha256:3f3b6ca8eae6d6c75a6cff525c8530c60e909a71a15e1b731723233331de4169"}, + {file = "regex-2024.7.24-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:c51edc3541e11fbe83f0c4d9412ef6c79f664a3745fab261457e84465ec9d5a8"}, + {file = "regex-2024.7.24-cp38-cp38-win32.whl", hash = "sha256:d0a07763776188b4db4c9c7fb1b8c494049f84659bb387b71c73bbc07f189e96"}, + {file = "regex-2024.7.24-cp38-cp38-win_amd64.whl", hash = "sha256:8fd5afd101dcf86a270d254364e0e8dddedebe6bd1ab9d5f732f274fa00499a5"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:0ffe3f9d430cd37d8fa5632ff6fb36d5b24818c5c986893063b4e5bdb84cdf24"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:25419b70ba00a16abc90ee5fce061228206173231f004437730b67ac77323f0d"}, + {file = "regex-2024.7.24-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:33e2614a7ce627f0cdf2ad104797d1f68342d967de3695678c0cb84f530709f8"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d33a0021893ede5969876052796165bab6006559ab845fd7b515a30abdd990dc"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:04ce29e2c5fedf296b1a1b0acc1724ba93a36fb14031f3abfb7abda2806c1535"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b16582783f44fbca6fcf46f61347340c787d7530d88b4d590a397a47583f31dd"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:836d3cc225b3e8a943d0b02633fb2f28a66e281290302a79df0e1eaa984ff7c1"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:438d9f0f4bc64e8dea78274caa5af971ceff0f8771e1a2333620969936ba10be"}, + {file = "regex-2024.7.24-cp39-cp39-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_12_x86_64.manylinux2010_x86_64.whl", hash = "sha256:973335b1624859cb0e52f96062a28aa18f3a5fc77a96e4a3d6d76e29811a0e6e"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:c5e69fd3eb0b409432b537fe3c6f44ac089c458ab6b78dcec14478422879ec5f"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:fbf8c2f00904eaf63ff37718eb13acf8e178cb940520e47b2f05027f5bb34ce3"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:ae2757ace61bc4061b69af19e4689fa4416e1a04840f33b441034202b5cd02d4"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_s390x.whl", hash = "sha256:44fc61b99035fd9b3b9453f1713234e5a7c92a04f3577252b45feefe1b327759"}, + {file = "regex-2024.7.24-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:84c312cdf839e8b579f504afcd7b65f35d60b6285d892b19adea16355e8343c9"}, + {file = "regex-2024.7.24-cp39-cp39-win32.whl", hash = "sha256:ca5b2028c2f7af4e13fb9fc29b28d0ce767c38c7facdf64f6c2cd040413055f1"}, + {file = "regex-2024.7.24-cp39-cp39-win_amd64.whl", hash = "sha256:7c479f5ae937ec9985ecaf42e2e10631551d909f203e31308c12d703922742f9"}, + {file = "regex-2024.7.24.tar.gz", hash = "sha256:9cfd009eed1a46b27c14039ad5bbc5e71b6367c5b2e6d5f5da0ea91600817506"}, ] [[package]] name = "requests" -version = "2.31.0" +version = "2.32.3" description = "Python HTTP for Humans." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "requests-2.31.0-py3-none-any.whl", hash = "sha256:58cd2187c01e70e6e26505bca751777aa9f2ee0b7f4300988b709f44e013003f"}, - {file = "requests-2.31.0.tar.gz", hash = "sha256:942c5a758f98d790eaed1a29cb6eefc7ffb0d1cf7af05c3d2791656dbd6ad1e1"}, + {file = "requests-2.32.3-py3-none-any.whl", hash = "sha256:70761cfe03c773ceb22aa2f671b4757976145175cdfca038c02654d061d6dcc6"}, + {file = "requests-2.32.3.tar.gz", hash = "sha256:55365417734eb18255590a9ff9eb97e9e1da868d4ccd6402399eaf68af20a760"}, ] [package.dependencies] @@ -1028,6 +1332,17 @@ urllib3 = ">=1.21.1,<3" socks = ["PySocks (>=1.5.6,!=1.5.7)"] use-chardet-on-py3 = ["chardet (>=3.0.2,<6)"] +[[package]] +name = "sniffio" +version = "1.3.1" +description = "Sniff out which async library your code is running under" +optional = false +python-versions = ">=3.7" +files = [ + {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, + {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, +] + [[package]] name = "toml" version = "0.10.2" @@ -1051,187 +1366,157 @@ files = [ ] [[package]] -name = "typed-ast" -version = "1.5.5" -description = "a fork of Python 2 and 3 ast modules with type comment support" +name = "tqdm" +version = "4.66.5" +description = "Fast, Extensible Progress Meter" optional = false -python-versions = ">=3.6" +python-versions = ">=3.7" files = [ - {file = "typed_ast-1.5.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:4bc1efe0ce3ffb74784e06460f01a223ac1f6ab31c6bc0376a21184bf5aabe3b"}, - {file = "typed_ast-1.5.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:5f7a8c46a8b333f71abd61d7ab9255440d4a588f34a21f126bbfc95f6049e686"}, - {file = "typed_ast-1.5.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:597fc66b4162f959ee6a96b978c0435bd63791e31e4f410622d19f1686d5e769"}, - {file = "typed_ast-1.5.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d41b7a686ce653e06c2609075d397ebd5b969d821b9797d029fccd71fdec8e04"}, - {file = "typed_ast-1.5.5-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:5fe83a9a44c4ce67c796a1b466c270c1272e176603d5e06f6afbc101a572859d"}, - {file = "typed_ast-1.5.5-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:d5c0c112a74c0e5db2c75882a0adf3133adedcdbfd8cf7c9d6ed77365ab90a1d"}, - {file = "typed_ast-1.5.5-cp310-cp310-win_amd64.whl", hash = "sha256:e1a976ed4cc2d71bb073e1b2a250892a6e968ff02aa14c1f40eba4f365ffec02"}, - {file = "typed_ast-1.5.5-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c631da9710271cb67b08bd3f3813b7af7f4c69c319b75475436fcab8c3d21bee"}, - {file = "typed_ast-1.5.5-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b445c2abfecab89a932b20bd8261488d574591173d07827c1eda32c457358b18"}, - {file = "typed_ast-1.5.5-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cc95ffaaab2be3b25eb938779e43f513e0e538a84dd14a5d844b8f2932593d88"}, - {file = "typed_ast-1.5.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:61443214d9b4c660dcf4b5307f15c12cb30bdfe9588ce6158f4a005baeb167b2"}, - {file = "typed_ast-1.5.5-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:6eb936d107e4d474940469e8ec5b380c9b329b5f08b78282d46baeebd3692dc9"}, - {file = "typed_ast-1.5.5-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e48bf27022897577d8479eaed64701ecaf0467182448bd95759883300ca818c8"}, - {file = "typed_ast-1.5.5-cp311-cp311-win_amd64.whl", hash = "sha256:83509f9324011c9a39faaef0922c6f720f9623afe3fe220b6d0b15638247206b"}, - {file = "typed_ast-1.5.5-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:44f214394fc1af23ca6d4e9e744804d890045d1643dd7e8229951e0ef39429b5"}, - {file = "typed_ast-1.5.5-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:118c1ce46ce58fda78503eae14b7664163aa735b620b64b5b725453696f2a35c"}, - {file = "typed_ast-1.5.5-cp36-cp36m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:be4919b808efa61101456e87f2d4c75b228f4e52618621c77f1ddcaae15904fa"}, - {file = "typed_ast-1.5.5-cp36-cp36m-musllinux_1_1_aarch64.whl", hash = "sha256:fc2b8c4e1bc5cd96c1a823a885e6b158f8451cf6f5530e1829390b4d27d0807f"}, - {file = "typed_ast-1.5.5-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:16f7313e0a08c7de57f2998c85e2a69a642e97cb32f87eb65fbfe88381a5e44d"}, - {file = "typed_ast-1.5.5-cp36-cp36m-win_amd64.whl", hash = "sha256:2b946ef8c04f77230489f75b4b5a4a6f24c078be4aed241cfabe9cbf4156e7e5"}, - {file = "typed_ast-1.5.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:2188bc33d85951ea4ddad55d2b35598b2709d122c11c75cffd529fbc9965508e"}, - {file = "typed_ast-1.5.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:0635900d16ae133cab3b26c607586131269f88266954eb04ec31535c9a12ef1e"}, - {file = "typed_ast-1.5.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:57bfc3cf35a0f2fdf0a88a3044aafaec1d2f24d8ae8cd87c4f58d615fb5b6311"}, - {file = "typed_ast-1.5.5-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:fe58ef6a764de7b4b36edfc8592641f56e69b7163bba9f9c8089838ee596bfb2"}, - {file = "typed_ast-1.5.5-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:d09d930c2d1d621f717bb217bf1fe2584616febb5138d9b3e8cdd26506c3f6d4"}, - {file = "typed_ast-1.5.5-cp37-cp37m-win_amd64.whl", hash = "sha256:d40c10326893ecab8a80a53039164a224984339b2c32a6baf55ecbd5b1df6431"}, - {file = "typed_ast-1.5.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fd946abf3c31fb50eee07451a6aedbfff912fcd13cf357363f5b4e834cc5e71a"}, - {file = "typed_ast-1.5.5-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:ed4a1a42df8a3dfb6b40c3d2de109e935949f2f66b19703eafade03173f8f437"}, - {file = "typed_ast-1.5.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:045f9930a1550d9352464e5149710d56a2aed23a2ffe78946478f7b5416f1ede"}, - {file = "typed_ast-1.5.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:381eed9c95484ceef5ced626355fdc0765ab51d8553fec08661dce654a935db4"}, - {file = "typed_ast-1.5.5-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:bfd39a41c0ef6f31684daff53befddae608f9daf6957140228a08e51f312d7e6"}, - {file = "typed_ast-1.5.5-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8c524eb3024edcc04e288db9541fe1f438f82d281e591c548903d5b77ad1ddd4"}, - {file = "typed_ast-1.5.5-cp38-cp38-win_amd64.whl", hash = "sha256:7f58fabdde8dcbe764cef5e1a7fcb440f2463c1bbbec1cf2a86ca7bc1f95184b"}, - {file = "typed_ast-1.5.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:042eb665ff6bf020dd2243307d11ed626306b82812aba21836096d229fdc6a10"}, - {file = "typed_ast-1.5.5-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:622e4a006472b05cf6ef7f9f2636edc51bda670b7bbffa18d26b255269d3d814"}, - {file = "typed_ast-1.5.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1efebbbf4604ad1283e963e8915daa240cb4bf5067053cf2f0baadc4d4fb51b8"}, - {file = "typed_ast-1.5.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f0aefdd66f1784c58f65b502b6cf8b121544680456d1cebbd300c2c813899274"}, - {file = "typed_ast-1.5.5-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:48074261a842acf825af1968cd912f6f21357316080ebaca5f19abbb11690c8a"}, - {file = "typed_ast-1.5.5-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:429ae404f69dc94b9361bb62291885894b7c6fb4640d561179548c849f8492ba"}, - {file = "typed_ast-1.5.5-cp39-cp39-win_amd64.whl", hash = "sha256:335f22ccb244da2b5c296e6f96b06ee9bed46526db0de38d2f0e5a6597b81155"}, - {file = "typed_ast-1.5.5.tar.gz", hash = "sha256:94282f7a354f36ef5dbce0ef3467ebf6a258e370ab33d5b40c249fa996e590dd"}, + {file = "tqdm-4.66.5-py3-none-any.whl", hash = "sha256:90279a3770753eafc9194a0364852159802111925aa30eb3f9d85b0e805ac7cd"}, + {file = "tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad"}, ] +[package.dependencies] +colorama = {version = "*", markers = "platform_system == \"Windows\""} + +[package.extras] +dev = ["pytest (>=6)", "pytest-cov", "pytest-timeout", "pytest-xdist"] +notebook = ["ipywidgets (>=6)"] +slack = ["slack-sdk"] +telegram = ["requests"] + [[package]] name = "typing-extensions" -version = "4.7.1" -description = "Backported and Experimental Type Hints for Python 3.7+" +version = "4.12.2" +description = "Backported and Experimental Type Hints for Python 3.8+" optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "typing_extensions-4.7.1-py3-none-any.whl", hash = "sha256:440d5dd3af93b060174bf433bccd69b0babc3b15b1a8dca43789fd7f61514b36"}, - {file = "typing_extensions-4.7.1.tar.gz", hash = "sha256:b75ddc264f0ba5615db7ba217daeb99701ad295353c45f9e95963337ceeeffb2"}, + {file = "typing_extensions-4.12.2-py3-none-any.whl", hash = "sha256:04e5ca0351e0f3f85c6853954072df659d0d13fac324d0072316b67d7794700d"}, + {file = "typing_extensions-4.12.2.tar.gz", hash = "sha256:1a7ead55c7e559dd4dee8856e3a88b41225abfe1ce8df57b7c13915fe121ffb8"}, ] [[package]] name = "urllib3" -version = "2.0.3" +version = "2.2.2" description = "HTTP library with thread-safe connection pooling, file post, and more." optional = false -python-versions = ">=3.7" +python-versions = ">=3.8" files = [ - {file = "urllib3-2.0.3-py3-none-any.whl", hash = "sha256:48e7fafa40319d358848e1bc6809b208340fafe2096f1725d05d67443d0483d1"}, - {file = "urllib3-2.0.3.tar.gz", hash = "sha256:bee28b5e56addb8226c96f7f13ac28cb4c301dd5ea8a6ca179c0b9835e032825"}, + {file = "urllib3-2.2.2-py3-none-any.whl", hash = "sha256:a448b2f64d686155468037e1ace9f2d2199776e17f0a46610480d311f73e3472"}, + {file = "urllib3-2.2.2.tar.gz", hash = "sha256:dd505485549a7a552833da5e6063639d0d177c04f23bc3864e41e5dc5f612168"}, ] [package.extras] brotli = ["brotli (>=1.0.9)", "brotlicffi (>=0.8.0)"] -secure = ["certifi", "cryptography (>=1.9)", "idna (>=2.0.0)", "pyopenssl (>=17.1.0)", "urllib3-secure-extra"] +h2 = ["h2 (>=4,<5)"] socks = ["pysocks (>=1.5.6,!=1.5.7,<2.0)"] zstd = ["zstandard (>=0.18.0)"] [[package]] name = "yarl" -version = "1.9.2" +version = "1.9.4" description = "Yet another URL library" optional = false python-versions = ">=3.7" files = [ - {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, - {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:82aa6264b36c50acfb2424ad5ca537a2060ab6de158a5bd2a72a032cc75b9eb8"}, - {file = "yarl-1.9.2-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c0c77533b5ed4bcc38e943178ccae29b9bcf48ffd1063f5821192f23a1bd27b9"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ee4afac41415d52d53a9833ebae7e32b344be72835bbb589018c9e938045a560"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9bf345c3a4f5ba7f766430f97f9cc1320786f19584acc7086491f45524a551ac"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:2a96c19c52ff442a808c105901d0bdfd2e28575b3d5f82e2f5fd67e20dc5f4ea"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:891c0e3ec5ec881541f6c5113d8df0315ce5440e244a716b95f2525b7b9f3608"}, - {file = "yarl-1.9.2-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:c3a53ba34a636a256d767c086ceb111358876e1fb6b50dfc4d3f4951d40133d5"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:566185e8ebc0898b11f8026447eacd02e46226716229cea8db37496c8cdd26e0"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:2b0738fb871812722a0ac2154be1f049c6223b9f6f22eec352996b69775b36d4"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:32f1d071b3f362c80f1a7d322bfd7b2d11e33d2adf395cc1dd4df36c9c243095"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:e9fdc7ac0d42bc3ea78818557fab03af6181e076a2944f43c38684b4b6bed8e3"}, - {file = "yarl-1.9.2-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:56ff08ab5df8429901ebdc5d15941b59f6253393cb5da07b4170beefcf1b2528"}, - {file = "yarl-1.9.2-cp310-cp310-win32.whl", hash = "sha256:8ea48e0a2f931064469bdabca50c2f578b565fc446f302a79ba6cc0ee7f384d3"}, - {file = "yarl-1.9.2-cp310-cp310-win_amd64.whl", hash = "sha256:50f33040f3836e912ed16d212f6cc1efb3231a8a60526a407aeb66c1c1956dde"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:646d663eb2232d7909e6601f1a9107e66f9791f290a1b3dc7057818fe44fc2b6"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:aff634b15beff8902d1f918012fc2a42e0dbae6f469fce134c8a0dc51ca423bb"}, - {file = "yarl-1.9.2-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:a83503934c6273806aed765035716216cc9ab4e0364f7f066227e1aaea90b8d0"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:b25322201585c69abc7b0e89e72790469f7dad90d26754717f3310bfe30331c2"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:22a94666751778629f1ec4280b08eb11815783c63f52092a5953faf73be24191"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:8ec53a0ea2a80c5cd1ab397925f94bff59222aa3cf9c6da938ce05c9ec20428d"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:159d81f22d7a43e6eabc36d7194cb53f2f15f498dbbfa8edc8a3239350f59fe7"}, - {file = "yarl-1.9.2-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:832b7e711027c114d79dffb92576acd1bd2decc467dec60e1cac96912602d0e6"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:95d2ecefbcf4e744ea952d073c6922e72ee650ffc79028eb1e320e732898d7e8"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:d4e2c6d555e77b37288eaf45b8f60f0737c9efa3452c6c44626a5455aeb250b9"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:783185c75c12a017cc345015ea359cc801c3b29a2966c2655cd12b233bf5a2be"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:b8cc1863402472f16c600e3e93d542b7e7542a540f95c30afd472e8e549fc3f7"}, - {file = "yarl-1.9.2-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:822b30a0f22e588b32d3120f6d41e4ed021806418b4c9f0bc3048b8c8cb3f92a"}, - {file = "yarl-1.9.2-cp311-cp311-win32.whl", hash = "sha256:a60347f234c2212a9f0361955007fcf4033a75bf600a33c88a0a8e91af77c0e8"}, - {file = "yarl-1.9.2-cp311-cp311-win_amd64.whl", hash = "sha256:be6b3fdec5c62f2a67cb3f8c6dbf56bbf3f61c0f046f84645cd1ca73532ea051"}, - {file = "yarl-1.9.2-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:38a3928ae37558bc1b559f67410df446d1fbfa87318b124bf5032c31e3447b74"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ac9bb4c5ce3975aeac288cfcb5061ce60e0d14d92209e780c93954076c7c4367"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:3da8a678ca8b96c8606bbb8bfacd99a12ad5dd288bc6f7979baddd62f71c63ef"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:13414591ff516e04fcdee8dc051c13fd3db13b673c7a4cb1350e6b2ad9639ad3"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bf74d08542c3a9ea97bb8f343d4fcbd4d8f91bba5ec9d5d7f792dbe727f88938"}, - {file = "yarl-1.9.2-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:6e7221580dc1db478464cfeef9b03b95c5852cc22894e418562997df0d074ccc"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:494053246b119b041960ddcd20fd76224149cfea8ed8777b687358727911dd33"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:52a25809fcbecfc63ac9ba0c0fb586f90837f5425edfd1ec9f3372b119585e45"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:e65610c5792870d45d7b68c677681376fcf9cc1c289f23e8e8b39c1485384185"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:1b1bba902cba32cdec51fca038fd53f8beee88b77efc373968d1ed021024cc04"}, - {file = "yarl-1.9.2-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:662e6016409828ee910f5d9602a2729a8a57d74b163c89a837de3fea050c7582"}, - {file = "yarl-1.9.2-cp37-cp37m-win32.whl", hash = "sha256:f364d3480bffd3aa566e886587eaca7c8c04d74f6e8933f3f2c996b7f09bee1b"}, - {file = "yarl-1.9.2-cp37-cp37m-win_amd64.whl", hash = "sha256:6a5883464143ab3ae9ba68daae8e7c5c95b969462bbe42e2464d60e7e2698368"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:5610f80cf43b6202e2c33ba3ec2ee0a2884f8f423c8f4f62906731d876ef4fac"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:b9a4e67ad7b646cd6f0938c7ebfd60e481b7410f574c560e455e938d2da8e0f4"}, - {file = "yarl-1.9.2-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:83fcc480d7549ccebe9415d96d9263e2d4226798c37ebd18c930fce43dfb9574"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fcd436ea16fee7d4207c045b1e340020e58a2597301cfbcfdbe5abd2356c2fb"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84e0b1599334b1e1478db01b756e55937d4614f8654311eb26012091be109d59"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:3458a24e4ea3fd8930e934c129b676c27452e4ebda80fbe47b56d8c6c7a63a9e"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:838162460b3a08987546e881a2bfa573960bb559dfa739e7800ceeec92e64417"}, - {file = "yarl-1.9.2-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f4e2d08f07a3d7d3e12549052eb5ad3eab1c349c53ac51c209a0e5991bbada78"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:de119f56f3c5f0e2fb4dee508531a32b069a5f2c6e827b272d1e0ff5ac040333"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:149ddea5abf329752ea5051b61bd6c1d979e13fbf122d3a1f9f0c8be6cb6f63c"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:674ca19cbee4a82c9f54e0d1eee28116e63bc6fd1e96c43031d11cbab8b2afd5"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:9b3152f2f5677b997ae6c804b73da05a39daa6a9e85a512e0e6823d81cdad7cc"}, - {file = "yarl-1.9.2-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:5415d5a4b080dc9612b1b63cba008db84e908b95848369aa1da3686ae27b6d2b"}, - {file = "yarl-1.9.2-cp38-cp38-win32.whl", hash = "sha256:f7a3d8146575e08c29ed1cd287068e6d02f1c7bdff8970db96683b9591b86ee7"}, - {file = "yarl-1.9.2-cp38-cp38-win_amd64.whl", hash = "sha256:63c48f6cef34e6319a74c727376e95626f84ea091f92c0250a98e53e62c77c72"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:75df5ef94c3fdc393c6b19d80e6ef1ecc9ae2f4263c09cacb178d871c02a5ba9"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c027a6e96ef77d401d8d5a5c8d6bc478e8042f1e448272e8d9752cb0aff8b5c8"}, - {file = "yarl-1.9.2-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:f3b078dbe227f79be488ffcfc7a9edb3409d018e0952cf13f15fd6512847f3f7"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:59723a029760079b7d991a401386390c4be5bfec1e7dd83e25a6a0881859e716"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b03917871bf859a81ccb180c9a2e6c1e04d2f6a51d953e6a5cdd70c93d4e5a2a"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:c1012fa63eb6c032f3ce5d2171c267992ae0c00b9e164efe4d73db818465fac3"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a74dcbfe780e62f4b5a062714576f16c2f3493a0394e555ab141bf0d746bb955"}, - {file = "yarl-1.9.2-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:8c56986609b057b4839968ba901944af91b8e92f1725d1a2d77cbac6972b9ed1"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:2c315df3293cd521033533d242d15eab26583360b58f7ee5d9565f15fee1bef4"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:b7232f8dfbd225d57340e441d8caf8652a6acd06b389ea2d3222b8bc89cbfca6"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:53338749febd28935d55b41bf0bcc79d634881195a39f6b2f767870b72514caf"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:066c163aec9d3d073dc9ffe5dd3ad05069bcb03fcaab8d221290ba99f9f69ee3"}, - {file = "yarl-1.9.2-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:8288d7cd28f8119b07dd49b7230d6b4562f9b61ee9a4ab02221060d21136be80"}, - {file = "yarl-1.9.2-cp39-cp39-win32.whl", hash = "sha256:b124e2a6d223b65ba8768d5706d103280914d61f5cae3afbc50fc3dfcc016623"}, - {file = "yarl-1.9.2-cp39-cp39-win_amd64.whl", hash = "sha256:61016e7d582bc46a5378ffdd02cd0314fb8ba52f40f9cf4d9a5e7dbef88dee18"}, - {file = "yarl-1.9.2.tar.gz", hash = "sha256:04ab9d4b9f587c06d801c2abfe9317b77cdf996c65a90d5e84ecc45010823571"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:a8c1df72eb746f4136fe9a2e72b0c9dc1da1cbd23b5372f94b5820ff8ae30e0e"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a3a6ed1d525bfb91b3fc9b690c5a21bb52de28c018530ad85093cc488bee2dd2"}, + {file = "yarl-1.9.4-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:c38c9ddb6103ceae4e4498f9c08fac9b590c5c71b0370f98714768e22ac6fa66"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d9e09c9d74f4566e905a0b8fa668c58109f7624db96a2171f21747abc7524234"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b8477c1ee4bd47c57d49621a062121c3023609f7a13b8a46953eb6c9716ca392"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:d5ff2c858f5f6a42c2a8e751100f237c5e869cbde669a724f2062d4c4ef93551"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:357495293086c5b6d34ca9616a43d329317feab7917518bc97a08f9e55648455"}, + {file = "yarl-1.9.4-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:54525ae423d7b7a8ee81ba189f131054defdb122cde31ff17477951464c1691c"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_aarch64.whl", hash = "sha256:801e9264d19643548651b9db361ce3287176671fb0117f96b5ac0ee1c3530d53"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:e516dc8baf7b380e6c1c26792610230f37147bb754d6426462ab115a02944385"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:7d5aaac37d19b2904bb9dfe12cdb08c8443e7ba7d2852894ad448d4b8f442863"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_s390x.whl", hash = "sha256:54beabb809ffcacbd9d28ac57b0db46e42a6e341a030293fb3185c409e626b8b"}, + {file = "yarl-1.9.4-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:bac8d525a8dbc2a1507ec731d2867025d11ceadcb4dd421423a5d42c56818541"}, + {file = "yarl-1.9.4-cp310-cp310-win32.whl", hash = "sha256:7855426dfbddac81896b6e533ebefc0af2f132d4a47340cee6d22cac7190022d"}, + {file = "yarl-1.9.4-cp310-cp310-win_amd64.whl", hash = "sha256:848cd2a1df56ddbffeb375535fb62c9d1645dde33ca4d51341378b3f5954429b"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:35a2b9396879ce32754bd457d31a51ff0a9d426fd9e0e3c33394bf4b9036b099"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:4c7d56b293cc071e82532f70adcbd8b61909eec973ae9d2d1f9b233f3d943f2c"}, + {file = "yarl-1.9.4-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:d8a1c6c0be645c745a081c192e747c5de06e944a0d21245f4cf7c05e457c36e0"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:4b3c1ffe10069f655ea2d731808e76e0f452fc6c749bea04781daf18e6039525"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:549d19c84c55d11687ddbd47eeb348a89df9cb30e1993f1b128f4685cd0ebbf8"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:a7409f968456111140c1c95301cadf071bd30a81cbd7ab829169fb9e3d72eae9"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e23a6d84d9d1738dbc6e38167776107e63307dfc8ad108e580548d1f2c587f42"}, + {file = "yarl-1.9.4-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d8b889777de69897406c9fb0b76cdf2fd0f31267861ae7501d93003d55f54fbe"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_aarch64.whl", hash = "sha256:03caa9507d3d3c83bca08650678e25364e1843b484f19986a527630ca376ecce"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:4e9035df8d0880b2f1c7f5031f33f69e071dfe72ee9310cfc76f7b605958ceb9"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:c0ec0ed476f77db9fb29bca17f0a8fcc7bc97ad4c6c1d8959c507decb22e8572"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_s390x.whl", hash = "sha256:ee04010f26d5102399bd17f8df8bc38dc7ccd7701dc77f4a68c5b8d733406958"}, + {file = "yarl-1.9.4-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:49a180c2e0743d5d6e0b4d1a9e5f633c62eca3f8a86ba5dd3c471060e352ca98"}, + {file = "yarl-1.9.4-cp311-cp311-win32.whl", hash = "sha256:81eb57278deb6098a5b62e88ad8281b2ba09f2f1147c4767522353eaa6260b31"}, + {file = "yarl-1.9.4-cp311-cp311-win_amd64.whl", hash = "sha256:d1d2532b340b692880261c15aee4dc94dd22ca5d61b9db9a8a361953d36410b1"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:0d2454f0aef65ea81037759be5ca9947539667eecebca092733b2eb43c965a81"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:44d8ffbb9c06e5a7f529f38f53eda23e50d1ed33c6c869e01481d3fafa6b8142"}, + {file = "yarl-1.9.4-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:aaaea1e536f98754a6e5c56091baa1b6ce2f2700cc4a00b0d49eca8dea471074"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3777ce5536d17989c91696db1d459574e9a9bd37660ea7ee4d3344579bb6f129"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:9fc5fc1eeb029757349ad26bbc5880557389a03fa6ada41703db5e068881e5f2"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ea65804b5dc88dacd4a40279af0cdadcfe74b3e5b4c897aa0d81cf86927fee78"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:aa102d6d280a5455ad6a0f9e6d769989638718e938a6a0a2ff3f4a7ff8c62cc4"}, + {file = "yarl-1.9.4-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:09efe4615ada057ba2d30df871d2f668af661e971dfeedf0c159927d48bbeff0"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:008d3e808d03ef28542372d01057fd09168419cdc8f848efe2804f894ae03e51"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:6f5cb257bc2ec58f437da2b37a8cd48f666db96d47b8a3115c29f316313654ff"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:992f18e0ea248ee03b5a6e8b3b4738850ae7dbb172cc41c966462801cbf62cf7"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_s390x.whl", hash = "sha256:0e9d124c191d5b881060a9e5060627694c3bdd1fe24c5eecc8d5d7d0eb6faabc"}, + {file = "yarl-1.9.4-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:3986b6f41ad22988e53d5778f91855dc0399b043fc8946d4f2e68af22ee9ff10"}, + {file = "yarl-1.9.4-cp312-cp312-win32.whl", hash = "sha256:4b21516d181cd77ebd06ce160ef8cc2a5e9ad35fb1c5930882baff5ac865eee7"}, + {file = "yarl-1.9.4-cp312-cp312-win_amd64.whl", hash = "sha256:a9bd00dc3bc395a662900f33f74feb3e757429e545d831eef5bb280252631984"}, + {file = "yarl-1.9.4-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:63b20738b5aac74e239622d2fe30df4fca4942a86e31bf47a81a0e94c14df94f"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:d7d7f7de27b8944f1fee2c26a88b4dabc2409d2fea7a9ed3df79b67277644e17"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:c74018551e31269d56fab81a728f683667e7c28c04e807ba08f8c9e3bba32f14"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:ca06675212f94e7a610e85ca36948bb8fc023e458dd6c63ef71abfd482481aa5"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5aef935237d60a51a62b86249839b51345f47564208c6ee615ed2a40878dccdd"}, + {file = "yarl-1.9.4-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2b134fd795e2322b7684155b7855cc99409d10b2e408056db2b93b51a52accc7"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_aarch64.whl", hash = "sha256:d25039a474c4c72a5ad4b52495056f843a7ff07b632c1b92ea9043a3d9950f6e"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:f7d6b36dd2e029b6bcb8a13cf19664c7b8e19ab3a58e0fefbb5b8461447ed5ec"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:957b4774373cf6f709359e5c8c4a0af9f6d7875db657adb0feaf8d6cb3c3964c"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_s390x.whl", hash = "sha256:d7eeb6d22331e2fd42fce928a81c697c9ee2d51400bd1a28803965883e13cead"}, + {file = "yarl-1.9.4-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:6a962e04b8f91f8c4e5917e518d17958e3bdee71fd1d8b88cdce74dd0ebbf434"}, + {file = "yarl-1.9.4-cp37-cp37m-win32.whl", hash = "sha256:f3bc6af6e2b8f92eced34ef6a96ffb248e863af20ef4fde9448cc8c9b858b749"}, + {file = "yarl-1.9.4-cp37-cp37m-win_amd64.whl", hash = "sha256:ad4d7a90a92e528aadf4965d685c17dacff3df282db1121136c382dc0b6014d2"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:ec61d826d80fc293ed46c9dd26995921e3a82146feacd952ef0757236fc137be"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:8be9e837ea9113676e5754b43b940b50cce76d9ed7d2461df1af39a8ee674d9f"}, + {file = "yarl-1.9.4-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:bef596fdaa8f26e3d66af846bbe77057237cb6e8efff8cd7cc8dff9a62278bbf"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2d47552b6e52c3319fede1b60b3de120fe83bde9b7bddad11a69fb0af7db32f1"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:84fc30f71689d7fc9168b92788abc977dc8cefa806909565fc2951d02f6b7d57"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:4aa9741085f635934f3a2583e16fcf62ba835719a8b2b28fb2917bb0537c1dfa"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:206a55215e6d05dbc6c98ce598a59e6fbd0c493e2de4ea6cc2f4934d5a18d130"}, + {file = "yarl-1.9.4-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:07574b007ee20e5c375a8fe4a0789fad26db905f9813be0f9fef5a68080de559"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_aarch64.whl", hash = "sha256:5a2e2433eb9344a163aced6a5f6c9222c0786e5a9e9cac2c89f0b28433f56e23"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:6ad6d10ed9b67a382b45f29ea028f92d25bc0bc1daf6c5b801b90b5aa70fb9ec"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:6fe79f998a4052d79e1c30eeb7d6c1c1056ad33300f682465e1b4e9b5a188b78"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_s390x.whl", hash = "sha256:a825ec844298c791fd28ed14ed1bffc56a98d15b8c58a20e0e08c1f5f2bea1be"}, + {file = "yarl-1.9.4-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:8619d6915b3b0b34420cf9b2bb6d81ef59d984cb0fde7544e9ece32b4b3043c3"}, + {file = "yarl-1.9.4-cp38-cp38-win32.whl", hash = "sha256:686a0c2f85f83463272ddffd4deb5e591c98aac1897d65e92319f729c320eece"}, + {file = "yarl-1.9.4-cp38-cp38-win_amd64.whl", hash = "sha256:a00862fb23195b6b8322f7d781b0dc1d82cb3bcac346d1e38689370cc1cc398b"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:604f31d97fa493083ea21bd9b92c419012531c4e17ea6da0f65cacdcf5d0bd27"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8a854227cf581330ffa2c4824d96e52ee621dd571078a252c25e3a3b3d94a1b1"}, + {file = "yarl-1.9.4-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:ba6f52cbc7809cd8d74604cce9c14868306ae4aa0282016b641c661f981a6e91"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a6327976c7c2f4ee6816eff196e25385ccc02cb81427952414a64811037bbc8b"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:8397a3817d7dcdd14bb266283cd1d6fc7264a48c186b986f32e86d86d35fbac5"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:e0381b4ce23ff92f8170080c97678040fc5b08da85e9e292292aba67fdac6c34"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:23d32a2594cb5d565d358a92e151315d1b2268bc10f4610d098f96b147370136"}, + {file = "yarl-1.9.4-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ddb2a5c08a4eaaba605340fdee8fc08e406c56617566d9643ad8bf6852778fc7"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_aarch64.whl", hash = "sha256:26a1dc6285e03f3cc9e839a2da83bcbf31dcb0d004c72d0730e755b33466c30e"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:18580f672e44ce1238b82f7fb87d727c4a131f3a9d33a5e0e82b793362bf18b4"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:29e0f83f37610f173eb7e7b5562dd71467993495e568e708d99e9d1944f561ec"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_s390x.whl", hash = "sha256:1f23e4fe1e8794f74b6027d7cf19dc25f8b63af1483d91d595d4a07eca1fb26c"}, + {file = "yarl-1.9.4-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:db8e58b9d79200c76956cefd14d5c90af54416ff5353c5bfd7cbe58818e26ef0"}, + {file = "yarl-1.9.4-cp39-cp39-win32.whl", hash = "sha256:c7224cab95645c7ab53791022ae77a4509472613e839dab722a72abe5a684575"}, + {file = "yarl-1.9.4-cp39-cp39-win_amd64.whl", hash = "sha256:824d6c50492add5da9374875ce72db7a0733b29c2394890aef23d533106e2b15"}, + {file = "yarl-1.9.4-py3-none-any.whl", hash = "sha256:928cecb0ef9d5a7946eb6ff58417ad2fe9375762382f1bf5c55e61645f2c43ad"}, + {file = "yarl-1.9.4.tar.gz", hash = "sha256:566db86717cf8080b99b58b083b773a908ae40f06681e87e589a976faf8246bf"}, ] [package.dependencies] idna = ">=2.0" multidict = ">=4.0" -typing-extensions = {version = ">=3.7.4", markers = "python_version < \"3.8\""} - -[[package]] -name = "zipp" -version = "3.15.0" -description = "Backport of pathlib-compatible object wrapper for zip files" -optional = false -python-versions = ">=3.7" -files = [ - {file = "zipp-3.15.0-py3-none-any.whl", hash = "sha256:48904fc76a60e542af151aded95726c1a5c34ed43ab4134b597665c86d7ad556"}, - {file = "zipp-3.15.0.tar.gz", hash = "sha256:112929ad649da941c23de50f356a2b5570c954b65150642bccdd66bf194d224b"}, -] - -[package.extras] -docs = ["furo", "jaraco.packaging (>=9)", "jaraco.tidelift (>=1.4)", "rst.linker (>=1.9)", "sphinx (>=3.5)", "sphinx-lint"] -testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=1.3)", "pytest-flake8", "pytest-mypy (>=0.9.1)"] [metadata] lock-version = "2.0" -python-versions = "^3.7" -content-hash = "deae6349cd55b6da7e03a9a858e7bbfb678e97982b34324cef3af0be5dfa3a4a" +python-versions = "^3.8" +content-hash = "4572f90730a8c15e31847b5238b491116780f50b02fd1b08e45d6353baac1bf8" diff --git a/clients/python/pyproject.toml b/clients/python/pyproject.toml index 9d1db1f1..4c2bd931 100644 --- a/clients/python/pyproject.toml +++ b/clients/python/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "scale-llm-engine" -version = "0.0.0.beta3" +version = "0.0.0.beta44" description = "Scale LLM Engine Python client" license = "Apache-2.0" authors = ["Phil Chen "] @@ -12,10 +12,11 @@ packages = [{include = "llmengine"}] [tool.poetry.dependencies] -python = "^3.7" -pydantic = "^1.10" +python = "^3.8" +pydantic = ">=1.10.17" aiohttp = "^3.8" requests = "^2.31.0" +openai = "^1.30.0" [tool.poetry.dev-dependencies] pytest = "^6.2.5" @@ -28,6 +29,7 @@ pytest-mypy-plugins = "^1.10.1" [tool.pytest.ini_options] asyncio_mode = "auto" + [build-system] requires = ["poetry-core>=1.0.0"] build-backend = "poetry.core.masonry.api" diff --git a/clients/python/setup.py b/clients/python/setup.py index 4d14ebd3..986694bb 100644 --- a/clients/python/setup.py +++ b/clients/python/setup.py @@ -2,7 +2,8 @@ setup( name="scale-llm-engine", - python_requires=">=3.7", - version="0.0.0.beta3", + python_requires=">=3.8", + version="0.0.0.beta44", packages=find_packages(), + package_data={"llmengine": ["py.typed"]}, ) diff --git a/docs/CNAME b/docs/CNAME new file mode 100644 index 00000000..bd01b7f9 --- /dev/null +++ b/docs/CNAME @@ -0,0 +1 @@ +llm-engine.scale.com diff --git a/docs/api/data_types.md b/docs/api/data_types.md index 55f33028..0576329c 100644 --- a/docs/api/data_types.md +++ b/docs/api/data_types.md @@ -1,14 +1,14 @@ # 🐍 Python Client Data Type Reference ::: llmengine.CompletionOutput - selection: + options: members: - text - num_prompt_tokens - num_completion_tokens ::: llmengine.CompletionStreamOutput - selection: + options: members: - text - finished @@ -16,30 +16,136 @@ - num_completion_tokens ::: llmengine.CompletionSyncResponse + options: + members: + - request_id + - output ::: llmengine.CompletionStreamResponse + options: + members: + - request_id + - output ::: llmengine.CreateFineTuneResponse + options: + members: + - id ::: llmengine.GetFineTuneResponse + options: + members: + - id + - fine_tuned_model ::: llmengine.ListFineTunesResponse + options: + members: + - jobs ::: llmengine.CancelFineTuneResponse + options: + members: + - success ::: llmengine.GetLLMEndpointResponse - selection: + options: members: - name - source - inference_framework - id - model_name + - status - inference_framework_tag - num_shards - quantize - spec ::: llmengine.ListLLMEndpointsResponse + options: + members: + - model_endpoints ::: llmengine.DeleteLLMEndpointResponse + options: + members: + - deleted + +::: llmengine.ModelDownloadRequest + options: + members: + - model_name + - download_format + +::: llmengine.ModelDownloadResponse + options: + members: + - urls + +::: llmengine.UploadFileResponse + options: + members: + - id + +::: llmengine.GetFileResponse + options: + members: + - id + - filename + - size + +::: llmengine.GetFileContentResponse + options: + members: + - id + - content + +::: llmengine.ListFilesResponse + options: + members: + - files + +::: llmengine.DeleteFileResponse + options: + members: + - deleted + +::: llmengine.CreateBatchCompletionsRequestContent + options: + members: + - prompts + - max_new_tokens + - temperature + - stop_sequences + - return_token_log_probs + - presence_penalty + - frequency_penalty + - top_k + - top_p + +::: llmengine.CreateBatchCompletionsModelConfig + options: + members: + - model + - checkpoint_path + - labels + - num_shards + - quantize + - seed + +::: llmengine.CreateBatchCompletionsRequest + options: + members: + - input_data_path + - output_data_path + - content + - model_config + - data_parallelism + - max_runtime_sec + - tool_config + +::: llmengine.CreateBatchCompletionsResponse + options: + members: + - job_id diff --git a/docs/api/python_client.md b/docs/api/python_client.md index 820b2e56..3e338388 100644 --- a/docs/api/python_client.md +++ b/docs/api/python_client.md @@ -1,13 +1,14 @@ # 🐍 Python Client API Reference ::: llmengine.Completion - selection: + options: members: - create - acreate + - batch_create ::: llmengine.FineTune - selection: + options: members: - create - get @@ -16,8 +17,20 @@ - cancel ::: llmengine.Model - selection: + options: members: + - create + - get + - list + - update + - delete + - download + +::: llmengine.File + options: + members: + - upload - get + - download - list - delete diff --git a/docs/contributing.md b/docs/contributing.md index 37a6793a..8423c202 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -21,7 +21,7 @@ pip install -r requirements-docs.txt Our Python client API reference is autogenerated from our client. You can install the client in editable mode with ``` -pip install -r clients/python +pip install -e clients/python ``` ### Step 4: Run Locally diff --git a/docs/examples/finetuning.ipynb b/docs/examples/finetuning.ipynb index 48669318..16573392 100644 --- a/docs/examples/finetuning.ipynb +++ b/docs/examples/finetuning.ipynb @@ -44,7 +44,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 1, "metadata": {}, "outputs": [], "source": [ @@ -70,7 +70,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 2, "metadata": {}, "outputs": [ { @@ -79,7 +79,7 @@ "\"From: dougb@comm.mot.com (Doug Bank)\\nSubject: Re: Info needed for Cleveland tickets\\nReply-To: dougb@ecs.comm.mot.com\\nOrganization: Motorola Land Mobile Products Sector\\nDistribution: usa\\nNntp-Posting-Host: 145.1.146.35\\nLines: 17\\n\\nIn article <1993Apr1.234031.4950@leland.Stanford.EDU>, bohnert@leland.Stanford.EDU (matthew bohnert) writes:\\n\\n|> I'm going to be in Cleveland Thursday, April 15 to Sunday, April 18.\\n|> Does anybody know if the Tribe will be in town on those dates, and\\n|> if so, who're they playing and if tickets are available?\\n\\nThe tribe will be in town from April 16 to the 19th.\\nThere are ALWAYS tickets available! (Though they are playing Toronto,\\nand many Toronto fans make the trip to Cleveland as it is easier to\\nget tickets in Cleveland than in Toronto. Either way, I seriously\\ndoubt they will sell out until the end of the season.)\\n\\n-- \\nDoug Bank Private Systems Division\\ndougb@ecs.comm.mot.com Motorola Communications Sector\\ndougb@nwu.edu Schaumburg, Illinois\\ndougb@casbah.acns.nwu.edu 708-576-8207\"" ] }, - "execution_count": 3, + "execution_count": 2, "metadata": {}, "output_type": "execute_result" } @@ -90,7 +90,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 3, "metadata": {}, "outputs": [ { @@ -102,7 +102,7 @@ "Name: count, dtype: int64" ] }, - "execution_count": 4, + "execution_count": 3, "metadata": {}, "output_type": "execute_result" } @@ -113,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 4, "metadata": {}, "outputs": [ { @@ -125,7 +125,7 @@ "Name: count, dtype: int64" ] }, - "execution_count": 5, + "execution_count": 4, "metadata": {}, "output_type": "execute_result" } @@ -143,7 +143,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 5, "metadata": {}, "outputs": [], "source": [ @@ -167,7 +167,7 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 7, "metadata": {}, "outputs": [ { @@ -234,7 +234,7 @@ "4 baseball Prompt: Subject: Let it be Known\\nFrom: 974\n", " From: maX <maX@maxim.rinaco.msk.su>\\nSubject: ...\n", " hockey\n", - " \n", - " \n", - " \n", - " 988\n", - " From: jca2@cec1.wustl.edu (Joseph Charles Achk...\n", - " hockey\n", - " NHL\n", - " \n", - " \n", - " 997\n", - " From: apland@mala.bc.ca (Ron Apland)\\nSubject:...\n", - " hockey\n", - " \n", + " baseball\n", " \n", " \n", "\n", @@ -587,16 +575,12 @@ "text/plain": [ " raw_prompt response \\\n", "974 From: maX \\nSubject: ... hockey \n", - "988 From: jca2@cec1.wustl.edu (Joseph Charles Achk... hockey \n", - "997 From: apland@mala.bc.ca (Ron Apland)\\nSubject:... hockey \n", "\n", " predicted_response \n", - "974 \n", - "988 NHL \n", - "997 " + "974 baseball " ] }, - "execution_count": 23, + "execution_count": 19, "metadata": {}, "output_type": "execute_result" } diff --git a/docs/getting_started.md b/docs/getting_started.md index 23ef79f3..5dd3d422 100644 --- a/docs/getting_started.md +++ b/docs/getting_started.md @@ -1,8 +1,8 @@ # Getting Started -The fastest way to get started with LLM Engine is to use the Python client in this repository to -run inference and fine-tuning on Scale's infrastructure. This path does not require you to install -anything on your infrastructure, and Scale's free research preview gives you access to experimentation using open source LLMs. +**Note: As of October 31st 2024, LLM Engine's public demo service is sunsetted. We have thus removed the documentation +pieces relating to calling the demo service, procuring a Spellbook API key, etc. Please view our Self Hosting Guide instead. +We will however leave behind the Example Code snippets for posterity, and as a reference for self-hosted and Scale internal users.** To start, install LLM Engine via pip: @@ -11,32 +11,32 @@ To start, install LLM Engine via pip: pip install scale-llm-engine ``` -## Scale API Keys +## Scale user ID -Next, you need a Scale Spellbook API key. +Next, you need a Scale user ID. Recall that this is only applicable to Scale internal users for now, and we are just leaving +this note to serve as internal documentation. -### Retrieving your API Key - -To retrieve your API key, head to [Scale Spellbook](https://spellbook.scale.com) where -you will get an API key on the [settings](https://spellbook.scale.com/settings) page. - -!!! note "Different API Keys for different Scale Products" - - If you have leveraged Scale's platform for annotation work in the past, please note that your Spellbook API key will be different than the Scale Annotation API key. You will want to create a Spellbook API key before getting started. ### Set your API Key LLM Engine uses environment variables to access your API key. -Set this API key as the `SCALE_API_KEY` environment variable by running the following command in your terminal before you run your python application. +Set the `SCALE_API_KEY` environment variable to your Scale user ID by running the following command in your terminal before you run your python application. ``` -export SCALE_API_KEY="[Your API key]" +export SCALE_API_KEY="[Your Scale user ID]" ``` You can also add in the line above to your `.zshrc` or `.bash_profile` so it's automatically set for future sessions. +Alternatively, you can also set your API key using either of the following patterns: +``` +llmengine.api_engine.api_key = "abc" +llmengine.api_engine.set_api_key("abc") +``` +These patterns are useful for Jupyter Notebook users to set API keys without the need for using `os.environ`. + ## Example Code ### Sample Completion @@ -48,7 +48,7 @@ With your API key set, you can now send LLM Engine requests using the Python cli from llmengine import Completion response = Completion.create( - model="falcon-7b-instruct", + model="llama-2-7b", prompt="I'm opening a pancake restaurant that specializes in unique pancake shapes, colors, and flavors. List 3 quirky names I could name my restaurant.", max_new_tokens=100, temperature=0.2, @@ -66,7 +66,7 @@ import sys from llmengine import Completion stream = Completion.create( - model="falcon-7b-instruct", + model="llama-2-7b", prompt="Give me a 200 word summary on the current economic events in the US.", max_new_tokens=1000, temperature=0.2, @@ -77,4 +77,7 @@ for response in stream: if response.output: print(response.output.text, end="") sys.stdout.flush() + else: # an error occurred + print(response.error) # print the error message out + break ``` diff --git a/docs/guides/completions.md b/docs/guides/completions.md index e5b6fdde..56b4538d 100644 --- a/docs/guides/completions.md +++ b/docs/guides/completions.md @@ -1,21 +1,22 @@ -Language Models are trained to predict natural language and provide text outputs as a response -to their inputs. The inputs are called _prompts_ and outputs are referred to as _completions_. -LLMs take the input _prompts_ and chunk them into smaller units called _tokens_ to process and -generate language. Tokens may include trailing spaces and even sub-words. This process is +Language Models are trained to predict natural language and provide text outputs as a response +to their inputs. The inputs are called _prompts_ and outputs are referred to as _completions_. +LLMs take the input _prompts_ and chunk them into smaller units called _tokens_ to process and +generate language. Tokens may include trailing spaces and even sub-words. This process is language dependent. -Scale's LLM Engine provides access to open source language models (see [Model Zoo](../../model_zoo)) +Scale's LLM Engine provides access to open source language models (see [Model Zoo](../../model_zoo)) that can be used for producing completions to prompts. ## Completion API call An example API call looks as follows: +=== "Completion call in Python" ```python from llmengine import Completion response = Completion.create( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, @@ -28,11 +29,12 @@ print(response.output.text) # ________ and I am a ________ ``` -- **model:** The LLM you want to use (see [Model Zoo](../../model_zoo)). -- **prompt:** The main input for the LLM to respond to. -- **max_new_tokens:** The maximum number of tokens to generate in the chat completion. -- **temperature:** The sampling temperature to use. Higher values make the output more random, -while lower values will make it more focused and deterministic. +- **model:** The LLM you want to use (see [Model Zoo](../../model_zoo)). +- **prompt:** The main input for the LLM to respond to. +- **max_new_tokens:** The maximum number of tokens to generate in the chat completion. +- **temperature:** The sampling temperature to use. Higher values make the output more random, + while lower values will make it more focused and deterministic. + When temperature is 0 [greedy search](https://huggingface.co/docs/transformers/generation_strategies#greedy-search) is used. See the full [Completion API reference documentation](../../api/python_client/#llmengine.Completion) to learn more. @@ -42,45 +44,47 @@ An example Completion API response looks as follows: === "Response in JSON" ```python - >>> print(response.json()) - ``` - Example output: - ```json - { - "request_id": "c4bf0732-08e0-48a8-8b44-dfe8d4702fb0", - "output": { - "text": "_______ and I am a _______", - "num_completion_tokens": 10 - } - } + >>> print(response.json()) + { + "request_id": "c4bf0732-08e0-48a8-8b44-dfe8d4702fb0", + "output": { + "text": "_______ and I am a _______", + "num_completion_tokens": 10 + } + } ``` === "Response in Python" ```python - >>> print(response.output.text) - ``` - Example output: - ``` - _______ and I am a _______ + >>> print(response.output.text) + _______ and I am a _______ ``` ## Token streaming -The Completions API supports token streaming to reduce _perceived_ latency for certain -applications. When streaming, tokens will be sent as data-only +The Completions API supports token streaming to reduce _perceived_ latency for certain +applications. When streaming, tokens will be sent as data-only [server-side events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#event_stream_format). To enable token streaming, pass `stream=True` to either [Completion.create](../../api/python_client/#llmengine.completion.Completion.create) or [Completion.acreate](../../api/python_client/#llmengine.completion.Completion.acreate). +### Streaming Error Handling + +Note: Error handling semantics are mixed for streaming calls: +- Errors that arise *before* streaming begins are returned back to the user as `HTTP` errors with the appropriate status code. +- Errors that arise *after* streaming begins within a `HTTP 200` response are returned back to the user as plain-text messages and currently need to be handled by the client. + An example of token streaming using the synchronous Completions API looks as follows: === "Token streaming with synchronous API in python" + ```python import sys from llmengine import Completion +# errors occurring before streaming begins will be thrown here stream = Completion.create( - model="falcon-7b-instruct", + model="llama-2-7b", prompt="Give me a 200 word summary on the current economic events in the US.", max_new_tokens=1000, temperature=0.2, @@ -91,6 +95,9 @@ for response in stream: if response.output: print(response.output.text, end="") sys.stdout.flush() + else: # an error occurred after streaming began + print(response.error) # print the error message out + break ``` ## Async requests @@ -101,13 +108,14 @@ to utilize async processing. The function signatures are otherwise identical. An example of async Completions looks as follows: === "Completions with asynchronous API in python" + ```python import asyncio from llmengine import Completion async def main(): response = await Completion.acreate( - model="llama-7b", + model="llama-2-7b", prompt="Hello, my name is", max_new_tokens=10, temperature=0.2, @@ -117,6 +125,149 @@ async def main(): asyncio.run(main()) ``` +## Batch completions + +The Python client also supports batch completions. Batch completions supports distributing data to multiple workers to accelerate inference. It also tries to maximize throughput so the completions should finish quite a bit faster than hitting models through HTTP. Use [Completion.batch_create](../../api/python_client/#llmengine.Completion.batch_create) to utilize batch completions. + +Some examples of batch completions: + +=== "Batch completions with prompts in the request" +```python +from llmengine import Completion +from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent + +content = CreateBatchCompletionsRequestContent( + prompts=["What is deep learning", "What is a neural network"], + max_new_tokens=10, + temperature=0.0 +) + +response = Completion.batch_create( + output_data_path="s3://my-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + content=content +) +print(response.job_id) +``` + +=== "Batch completions with prompts in a file and with 2 parallel jobs" +```python +from llmengine import Completion +from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent + +# Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" + +response = Completion.batch_create( + input_data_path="s3://my-input-path", + output_data_path="s3://my-output-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + data_parallelism=2 +) +print(response.job_id) +``` + +=== "Batch completions with prompts and use tool" +For how to properly use the tool please see [Completion.batch_create](../../api/python_client/#llmengine.Completion.batch_create) tool_config doc. +```python +from llmengine import Completion +from llmengine.data_types import CreateBatchCompletionsModelConfig, CreateBatchCompletionsRequestContent, ToolConfig + +# Store CreateBatchCompletionsRequestContent data into input file "s3://my-input-path" + +response = Completion.batch_create( + input_data_path="s3://my-input-path", + output_data_path="s3://my-output-path", + model_config=CreateBatchCompletionsModelConfig( + model="llama-2-7b", + checkpoint_path="s3://checkpoint-path", + labels={"team":"my-team", "product":"my-product"} + ), + data_parallelism=2, + tool_config=ToolConfig( + name="code_evaluator", + ) +) +print(response.json()) +``` + +## Guided decoding + +Guided decoding is supported by vLLM and backed by [Outlines](https://github.com/outlines-dev/outlines). +It enforces certain token generation patterns by tinkering with the sampling logits. + +=== "Guided decoding with regex" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_regex="Sean.*", +) + +print(response.json()) +# {"request_id":"c19f0fae-317e-4f69-8e06-c04189299b9c","output":{"text":"Sean. I'm a 2","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}} +``` + +=== "Guided decoding with choice" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_choice=["Sean", "Brian", "Tim"], +) + +print(response.json()) +# {"request_id":"641e2af3-a3e3-4493-98b9-d38115ba0d22","output":{"text":"Sean","num_prompt_tokens":6,"num_completion_tokens":4,"tokens":null}} +``` + +=== "Guided decoding with JSON schema" +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_json={"properties":{"myString":{"type":"string"}},"required":["myString"]}, +) + +print(response.json()) +# {"request_id":"5b184654-96b6-4932-9eb6-382a51fdb3d5","output":{"text":"{\"myString\" : \"John Doe","num_prompt_tokens":6,"num_completion_tokens":10,"tokens":null}} +``` + +=== "Guided decoding with Context-Free Grammar" + +```python +from llmengine import Completion + +response = Completion.create( + model="llama-2-7b", + prompt="Hello, my name is", + max_new_tokens=10, + temperature=0.2, + guided_grammar="start: \"John\"" +) + +print(response.json()) +# {"request_id": "34621b44-c655-402c-a459-f108b3e49b12", "output": {"text": "John", "num_prompt_tokens": 6, "num_completion_tokens": 4, "tokens": None}} +``` + ## Which model should I use? See the [Model Zoo](../../model_zoo) for more information on best practices for which model to use for Completions. diff --git a/docs/guides/endpoint_creation.md b/docs/guides/endpoint_creation.md new file mode 100644 index 00000000..e16602b7 --- /dev/null +++ b/docs/guides/endpoint_creation.md @@ -0,0 +1,17 @@ +When creating a model endpoint, you can periodically poll the model status field to +track the status of your model endpoint. In general, you'll need to wait after the +model creation step for the model endpoint to be ready and available for use. +An example is provided below: + + +``` +model_name = "test_deploy" +model = Model.create(name=model_name, model="llama-2-7b", inference_frame_image_tag="0.9.4") +response = Model.get(model_name) +while response.status.name != "READY": + print(response.status.name) + time.sleep(60) + response = Model.get(model_name) +``` + +Once the endpoint status is ready, you can use your newly created model for inference. \ No newline at end of file diff --git a/docs/guides/fine_tuning.md b/docs/guides/fine_tuning.md index ab83d0d1..705f1b23 100644 --- a/docs/guides/fine_tuning.md +++ b/docs/guides/fine_tuning.md @@ -22,15 +22,15 @@ The training data for fine-tuning should consist of prompt and response pairs. As a rule of thumb, you should expect to see linear improvements in your fine-tuned model's quality with each doubling of the dataset size. Having high-quality data is also essential to improving performance. For every linear increase in the error rate in your training data, you may encounter a roughly quadratic increase in your fine-tuned model's error rate. -High quality data is critical to achieve improved model performance, and in several cases will require _experts_ to -generate and prepare data - the breadth and diversity of the data is highly critical. Scale's Data Engine can help +High quality data is critical to achieve improved model performance, and in several cases will require _experts_ to +generate and prepare data - the breadth and diversity of the data is highly critical. Scale's Data Engine can help prepare such high quality, diverse data sets - more information [here](https://scale.com/rlhf). ## Preparing data -Your data must be formatted as a CSV file that includes two columns: `prompt` and `response`. A maximum of 100,000 rows of data is currently supported. At least 200 rows of data is recommended to start to see benefits from fine-tuning. -Here is an example script to create a 50-row CSV of properly formatted data for fine-tuning an airline question answering bot +Your data must be formatted as a CSV file that includes two columns: `prompt` and `response`. A maximum of 100,000 rows of data is currently supported. At least 200 rows of data is recommended to start to see benefits from fine-tuning. LLM Engine supports fine-tuning with a training and validation dataset. If only a training dataset is provided, 10% of the data is randomly split to be used as validation. +Here is an example script to create a 50-row CSV of properly formatted data for fine-tuning an airline question answering bot
Creating a sample dataset @@ -98,44 +98,51 @@ with open('customer_service_data.csv', 'w', newline='') as file: writer.writerow(["prompt", "response"]) writer.writerows(data) ``` +
## Making your data accessible to LLM Engine -Currently, data needs to be uploaded to a publicly accessible web URL so that it can be read -for fine-tuning. Publicly accessible HTTP and HTTPS URLs are currently supported. -Support for privately sharing data with the LLM Engine API is coming shortly. For quick -iteration, you can look into tools like Pastebin or GitHub Gists to quickly host your CSV -files in a public manner. An example Github Gist can be found -[here](https://gist.github.com/tigss/7cec73251a37de72756a3b15eace9965). To use the gist, -you can use the URL given when you click the “Raw” button -([URL](https://gist.githubusercontent.com/tigss/7cec73251a37de72756a3b15eace9965/raw/85d9742890e1e6b0c06468507292893b820c13c9/llm_sample_data.csv)). +Currently, data needs to be uploaded to either a publicly accessible web URL or to LLM Engine's private file server so that it can be read for fine-tuning. Publicly accessible HTTP and HTTPS URLs are currently supported. + +To privately share data with the LLM Engine API, use LLM Engine's [File.upload](../../api/python_client/#llmengine.File.upload) API. You can upload data in local file to LLM Engine's private file server and then use the returned file ID to reference your data in the FineTune API. The file ID is generally in the form of `file-`, e.g. "file-7DLVeLdN2Ty4M2m". + +=== "Upload to LLM Engine's private file server" + +```python +from llmengine import File + +response = File.upload(open("customer_service_data.csv", "r")) +print(response.json()) +``` ## Launching the fine-tune -Once you have uploaded your data, you can use the LLM Engine's [FineTune.Create](../../api/python_client/#llmengine.fine_tuning.FineTune.create) API to launch a fine-tune. You will need to specify which base model to fine-tune, the locations of the training file and optional validation data file, an optional set of hyperparameters to customize the fine-tuning behavior, and an optional suffix to append to the name of the fine-tune. For sequences longer than the native + +Once you have uploaded your data, you can use the LLM Engine's [FineTune.Create](../../api/python_client/#llmengine.fine_tuning.FineTune.create) API to launch a fine-tune. You will need to specify which base model to fine-tune, the locations of the training file and optional validation data file, an optional set of hyperparameters to customize the fine-tuning behavior, and an optional suffix to append to the name of the fine-tune. For sequences longer than the native `max_seq_length` of the model, the sequences will be truncated. -If you specify a suffix, the fine-tune will be named `model.suffix.`. If you do not, -the fine-tune will be named `model.`. The timestamp will be the time the fine-tune was -launched. +If you specify a suffix, the fine-tune will be named `model.suffix.`. If you do not, +the fine-tune will be named `model.`. The timestamp will be the time the fine-tune was +launched. Note: the suffix must only contain alphanumeric characters and hyphens, and be at most 28 characters long.
Hyper-parameters for fine-tune -* `lr`: Peak learning rate used during fine-tuning. It decays with a cosine schedule afterward. (Default: 2e-3) -* `warmup_ratio`: Ratio of training steps used for learning rate warmup. (Default: 0.03) -* `epochs`: Number of fine-tuning epochs. This should be less than 20. (Default: 5) -* `weight_decay`: Regularization penalty applied to learned weights. (Default: 0.001) +- `lr`: Peak learning rate used during fine-tuning. It decays with a cosine schedule afterward. (Default: 2e-3) +- `warmup_ratio`: Ratio of training steps used for learning rate warmup. (Default: 0.03) +- `epochs`: Number of fine-tuning epochs. This should be less than 20. (Default: 5) +- `weight_decay`: Regularization penalty applied to learned weights. (Default: 0.001)
-=== "Create a fine-tune in python" +=== "Create a fine-tune in python" ```python from llmengine import FineTune response = FineTune.create( - model="llama-7b", - training_file="s3://my-bucket/path/to/training-file.csv", + model="llama-2-7b", + training_file="file-AbCDeLdN2Ty4M2m", + validation_file="file-ezSRpgtKQyItI26", ) print(response.json()) @@ -143,18 +150,48 @@ print(response.json()) See the [Model Zoo](../../model_zoo) to see which models have fine-tuning support. -Once the fine-tune is launched, you can also [get the status of your fine-tune](../../api/python_client/#llmengine.fine_tuning.FineTune.get). You can also [list events that your fine-tune produces](../../api/python_client/#llmengine.fine_tuning.FineTune.get_events). +See [Integrations](../integrations.md) to see how to track fine-tuning metrics. + +## Monitoring the fine-tune + +Once the fine-tune is launched, you can also [get the status of your fine-tune](../../api/python_client/#llmengine.fine_tuning.FineTune.get). +You can also [list events that your fine-tune produces](../../api/python_client/#llmengine.fine_tuning.FineTune.get_events). +```python +from llmengine import FineTune + +fine_tune_id = "ft-cabcdefghi1234567890" +fine_tune = FineTune.get(fine_tune_id) +print(fine_tune.status) # BatchJobStatus.RUNNING +print(fine_tune.fine_tuned_model) # "llama-2-7b.700101-000000 + +fine_tune_events = FineTune.get_events(fine_tune_id) +for event in fine_tune_events.events: + print(event) +# Prints something like: +# timestamp=1697590000.0 message="{'loss': 12.345, 'learning_rate': 0.0, 'epoch': 0.97}" level='info' +# timestamp=1697590000.0 message="{'eval_loss': 23.456, 'eval_runtime': 19.876, 'eval_samples_per_second': 4.9, 'eval_steps_per_second': 4.9, 'epoch': 0.97}" level='info' +# timestamp=1697590020.0 message="{'train_runtime': 421.234, 'train_samples_per_second': 2.042, 'train_steps_per_second': 0.042, 'total_flos': 123.45, 'train_loss': 34.567, 'epoch': 0.97}" level='info' + + +``` + +The status of your fine-tune will give a high-level overview of the fine-tune's progress. +The events of your fine-tune will give more detail, such as the training loss and validation loss at each epoch, +as well as any errors that may have occurred. If you encounter any errors with your fine-tune, +the events are a good place to start debugging. For example, if you see `Unable to read training or validation dataset`, +you may need to make your files accessible to LLM Engine. If you see `Invalid value received for lora parameter 'lora_alpha'!`, +you should [check that your hyperparameters are valid](../../api/python_client/#llmengine.fine_tuning.FineTune.create). ## Making inference calls to your fine-tune -Once your fine-tune is finished, you will be able to start making inference requests to the -model. You can use the `fine_tuned_model` returned from your +Once your fine-tune is finished, you will be able to start making inference requests to the +model. You can use the `fine_tuned_model` returned from your [FineTune.get](../../api/python_client/#llmengine.fine_tuning.FineTune.get) -API call to reference your fine-tuned model in the Completions API. Alternatively, you can list -available LLMs with `Model.list` in order to find the name of your fine-tuned model. See the -[Completion API](../../api/python_client/#llmengine.Completion) for more details. You can then -use that name to direct your completion requests. You must wait until your fine-tune is complete -before you can plug it into the Completions API. You can check the status of your fine-tune with +API call to reference your fine-tuned model in the Completions API. Alternatively, you can list +available LLMs with `Model.list` in order to find the name of your fine-tuned model. See the +[Completion API](../../api/python_client/#llmengine.Completion) for more details. You can then +use that name to direct your completion requests. You must wait until your fine-tune is complete +before you can plug it into the Completions API. You can check the status of your fine-tune with [FineTune.get](../../api/python_client/#llmengine.fine_tuning.FineTune.get). === "Inference with a fine-tuned model in python" @@ -163,7 +200,7 @@ before you can plug it into the Completions API. You can check the status of you from llmengine import Completion response = Completion.create( - model="llama-7b.airlines.2023-07-17-08-30-45", + model="llama-2-7b.airlines.2023-07-17-08-30-45", prompt="Do you offer in-flight Wi-fi?", max_new_tokens=100, temperature=0.2, diff --git a/docs/guides/rate_limits.md b/docs/guides/rate_limits.md index 1224f19f..2aa59dd4 100644 --- a/docs/guides/rate_limits.md +++ b/docs/guides/rate_limits.md @@ -18,25 +18,26 @@ will return HTTP 429 on an as-needed basis. ## Retrying with exponential backoff -One easy way to avoid rate limit errors is to automatically retry requests with a random exponential backoff. -Retrying with exponential backoff means performing a short sleep when a rate limit error is hit, then retrying the -unsuccessful request. If the request is still unsuccessful, the sleep length is increased and the process is repeated. +One easy way to avoid rate limit errors is to automatically retry requests with a random exponential backoff. +Retrying with exponential backoff means performing a short sleep when a rate limit error is hit, then retrying the +unsuccessful request. If the request is still unsuccessful, the sleep length is increased and the process is repeated. This continues until the request is successful or until a maximum number of retries is reached. This approach has many benefits: -* Automatic retries means you can recover from rate limit errors without crashes or missing data -* Exponential backoff means that your first retries can be tried quickly, while still benefiting from longer delays if your first few retries fail -* Adding random jitter to the delay helps retries from all hitting at the same time. +- Automatic retries means you can recover from rate limit errors without crashes or missing data +- Exponential backoff means that your first retries can be tried quickly, while still benefiting from longer delays if your first few retries fail +- Adding random jitter to the delay helps retries from all hitting at the same time. Below are a few example solutions **for Python** that use exponential backoff. ### Example #1: Using the `tenacity` library -Tenacity is an Apache 2.0 licensed general-purpose retrying library, written in Python, to simplify the task of adding -retry behavior to just about anything. To add exponential backoff to your requests, you can use the tenacity.retry -decorator. The below example uses the tenacity.wait_random_exponential function to add random exponential backoff to a +Tenacity is an Apache 2.0 licensed general-purpose retrying library, written in Python, to simplify the task of adding +retry behavior to just about anything. To add exponential backoff to your requests, you can use the tenacity.retry +decorator. The below example uses the tenacity.wait_random_exponential function to add random exponential backoff to a request. === "Exponential backoff in python" + ```python import llmengine from tenacity import ( @@ -49,14 +50,15 @@ from tenacity import ( def completion_with_backoff(**kwargs): return llmengine.Completion.create(**kwargs) -completion_with_backoff(model="llama-7b", prompt="Why is the sky blue?") +completion_with_backoff(model="llama-2-7b", prompt="Why is the sky blue?") ``` ### Example #2: Using the `backoff` library -[Backoff](https://github.com/litl/backoff) is another python library that provides function decorators which can be used to wrap a function such that it will be retried until some condition is met. +[Backoff](https://github.com/litl/backoff) is another python library that provides function decorators which can be used to wrap a function such that it will be retried until some condition is met. === "Decorators for backoff and retry in python" + ```python import llmengine import backoff @@ -65,5 +67,5 @@ import backoff def completion_with_backoff(**kwargs): return llmengine.Completion.create(**kwargs) -completions_with_backoff(model="llama-7b", prompt="Why is the sky blue?") +completions_with_backoff(model="llama-2-7b", prompt="Why is the sky blue?") ``` diff --git a/docs/guides/self_hosting.md b/docs/guides/self_hosting.md index 39fa5bd4..442a16dc 100644 --- a/docs/guides/self_hosting.md +++ b/docs/guides/self_hosting.md @@ -1,7 +1,7 @@ -# [Experimental] Self Hosting +# Self Hosting _[Experimental]_ **This guide is currently highly experimental. Instructions are subject to change as we improve support for self-hosting.** -We provide a Helm chart that deploys LLM Engine to an [Elastic Kubernetes Cluster](https://aws.amazon.com/eks/). This Helm chart should be configured to connect to dependencies (such as a PostgreSQL database) that you may already have available in your environment. +We provide a Helm chart that deploys LLM Engine to an [Elastic Kubernetes Cluster](https://aws.amazon.com/eks/) (EKS) in [AWS](https://aws.amazon.com/). This Helm chart should be configured to connect to dependencies (such as a PostgreSQL database) that you may already have available in your environment. The only portions of the Helm chart that are production ready are the parts that configure and manage LLM Server itself (not PostgreSQL, IAM, etc.) @@ -21,8 +21,9 @@ Additionally, they must have the `k8s.amazonaws.com/accelerator` label set appro | --- | --- | | g4dn | nvidia-tesla-t4 | | g5 | nvidia-tesla-a10 | -| p4d | nvidia-tesla-a100 | -| p4de | nvidia-tesla-a100e | +| p4d | nvidia-ampere-a100 | +| p4de | nvidia-ampere-a100e | +| p5 | nvidia-hopper-h100 | We also recommend setting the following taint on your GPU nodes to prevent pods requiring GPU resources from being scheduled on them: - { key = "nvidia.com/gpu", value = "true", effect = "NO_SCHEDULE" } @@ -74,7 +75,7 @@ The LLM Engine server will an IAM role to perform various AWS operations. This r | `sqs:ListQueues` | `*` | | `ecr:BatchGetImage`, `ecr:DescribeImages`, `ecr:GetDownloadUrlForLayer`, `ecr:ListImages` | `${ecr_repository_arn}` | -# Helm Chart +## Helm Chart Now that all dependencies have been installed and configured, we can run the provided Helm chart. The values in the Helm chart will need to correspond with the resources described in the Dependencies section. Ensure that Helm V3 is installed [instructions](https://helm.sh/docs/intro/install/) and can connect to the EKS cluster. Users should be able to install the chart with `helm install llm-engine llm-engine -f llm-engine/values_sample.yaml -n `. @@ -112,6 +113,101 @@ Below are the configurations to specify in the `values_sample.yaml` file. | config.values.infra.redis_host | The hostname of the redis cluster you wish to connect | Yes | | config.values.infra.s3_bucket | The S3 bucket you wish to connect | Yes | | config.values.llm_engine.endpoint_namespace | K8s namespace the endpoints will be created in | Yes | -| config.values.llm_engine.cache_redis_url | The full url for the redis cluster you wish to connect | Yes | +| config.values.llm_engine.cache_redis_aws_url | The full url for the redis cluster you wish to connect | No | +| config.values.llm_engine.cache_redis_azure_host | The redis cluster host when using cloud_provider azure | No | | config.values.llm_engine.s3_file_llm_fine_tuning_job_repository | The S3 URI for the S3 bucket/key that you wish to save fine-tuned assets | Yes | -| config.values.datadog_trace_enabled | Whether to enable datadog tracing, datadog must be installed in the cluster | No | +| config.values.dd_trace_enabled | Whether to enable datadog tracing, datadog must be installed in the cluster | No | + +## Play With It +Once `helm install` succeeds, you can forward port `5000` from a `llm-engine` pod and test sending requests to it. + +First, see a list of pods in the namespace that you performed `helm install` in: +``` +$ kubectl get pods -n +NAME READY STATUS RESTARTS AGE +llm-engine-668679554-9q4wj 1/1 Running 0 18m +llm-engine-668679554-xfhxx 1/1 Running 0 18m +llm-engine-cacher-5f8b794585-fq7dj 1/1 Running 0 18m +llm-engine-endpoint-builder-5cd6bf5bbc-sm254 1/1 Running 0 18m +llm-engine-image-cache-a10-sw4pg 1/1 Running 0 18m +``` +Note the pod names you see may be different. + +Forward a port from a `llm-engine` pod: +``` +$ kubectl port-forward pod/llm-engine- 5000:5000 -n +``` + +Then, try sending a request to get LLM model endpoints for `test-user-id`: +``` +$ curl -X GET -H "Content-Type: application/json" -u "test-user-id:" "http://localhost:5000/v1/llm/model-endpoints" +``` + +You should get the following response: +``` +{"model_endpoints":[]} +``` + +Next, let's create a LLM endpoint using llama-7b: +``` +$ curl -X POST 'http://localhost:5000/v1/llm/model-endpoints' \ + -H 'Content-Type: application/json' \ + -d '{ + "name": "llama-7b", + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "text_generation_inference", + "inference_framework_image_tag": "0.9.3", + "num_shards": 4, + "endpoint_type": "streaming", + "cpus": 32, + "gpus": 4, + "memory": "40Gi", + "storage": "40Gi", + "gpu_type": "nvidia-ampere-a10", + "min_workers": 1, + "max_workers": 12, + "per_worker": 1, + "labels": {}, + "metadata": {} + }' \ + -u test_user_id: +``` + +It should output something like: +``` +{"endpoint_creation_task_id":"8d323344-b1b5-497d-a851-6d6284d2f8e4"} +``` + +Wait a few minutes for the endpoint to be ready. You can tell that it's ready by listing pods and checking that all containers in the llm endpoint pod are ready: +``` +$ kubectl get pods -n +NAME READY STATUS RESTARTS AGE +llm-engine-endpoint-id-end-cismpd08agn003rr2kc0-7f86ff64f9qj9xp 2/2 Running 1 (4m41s ago) 7m26s +``` +Note the endpoint name could be different. + +Then, you can send an inference request to the endppoint: +``` +$ curl -X POST 'http://localhost:5000/v1/llm/completions-sync?model_endpoint_name=llama-7b' \ + -H 'Content-Type: application/json' \ + -d '{ + "prompts": ["Tell me a joke about AI"], + "max_new_tokens": 30, + "temperature": 0.1 + }' \ + -u test-user-id: +``` + +You should get a response similar to: +``` +{"status":"SUCCESS","outputs":[{"text":". Tell me a joke about AI. Tell me a joke about AI. Tell me a joke about AI. Tell me","num_completion_tokens":30}],"traceback":null} +``` + +### Pointing LLM Engine client to use self-hosted infrastructure +The `llmengine` client makes requests to Scale AI's hosted infrastructure by default. You can have `llmengine` client make requests to your own self-hosted infrastructure by setting the `LLM_ENGINE_BASE_PATH` environment variable to the URL of the `llm-engine` service. + +The exact URL of `llm-engine` service depends on your Kubernetes cluster networking setup. The domain is specified at `config.values.infra.dns_host_domain` in the helm chart values config file. Using `charts/llm-engine/values_sample.yaml` as an example, you would do: +```bash +export LLM_ENGINE_BASE_PATH=https://llm-engine.domain.com +``` \ No newline at end of file diff --git a/docs/index.md b/docs/index.md index fcf8cf3e..01d4f84e 100644 --- a/docs/index.md +++ b/docs/index.md @@ -30,11 +30,11 @@ Kubernetes. ### Key Features **Ready-to-use APIs for your favorite models**: Deploy and serve -open source foundation models - including LLaMA, MPT, and Falcon. +open source foundation models - including Llama-2, MPT, and Falcon. Use Scale-hosted models or deploy to your own infrastructure. -**Fine-tune your favorite models**: Fine-tune open-source foundation -models like LLaMA, MPT, etc. with your own data for optimized performance. +**Fine-tune the best open-source models**: Fine-tune open-source foundation +models like Llama-2, MPT, etc. with your own data for optimized performance. **Optimized Inference**: LLM Engine provides inference APIs for streaming responses and dynamically batching inputs for higher throughput @@ -48,10 +48,11 @@ auto-scaling deployment with simple APIs. ### Features Coming Soon -**Kubernetes Installation Documentation**: We are working hard to document the installation and -maintenance of inference and fine-tuning functionality on your infrastructure. -For now, our documentation covers using our client libraries to access Scale's -hosted infrastructure. +**Kubernetes Installation Enhancements**: We are working hard to enhance the +installation and maintenance of inference and fine-tuning functionality on +your infrastructure. For now, our documentation covers _experimental_ libraries +to [deploy language models on your infrastructure](guides/self_hosting) +and libraries to access Scale's [hosted infrastructure](https://spellbook.scale.com). **Fast Cold-Start Times**: To prevent GPUs from idling, LLM Engine automatically scales your model to zero when it's not in use and scales up diff --git a/docs/integrations.md b/docs/integrations.md new file mode 100644 index 00000000..60674387 --- /dev/null +++ b/docs/integrations.md @@ -0,0 +1,24 @@ +# Integrations + +## Weights & Biases + +LLM Engine integrates with Weights & Biases to track metrics during fine tuning. To enable: + +```python +from llmengine import FineTune + +response = FineTune.create( + model="llama-2-7b", + training_file="s3://my-bucket/path/to/training-file.csv", + validation_file="s3://my-bucket/path/to/validation-file.csv", + hyperparameters={"report_to": "wandb"}, + wandb_config={"api_key":"key", "project":"fine-tune project"} +) +``` + +Configs to specify: + +- (Required) Set `hyperparameters.report_to` to `wandb` to enables automatic metrics tracking. +- (Required) Set `wandb_config.api_key` to the API key. +- (Optional) Set `wandb_config.base_url` to use a custom Weights & Biases server. +- `wandb_config` also accepts keys from [wandb.init()](https://docs.wandb.ai/ref/python/init). diff --git a/docs/model_zoo.md b/docs/model_zoo.md index f4e6875f..538f87e6 100644 --- a/docs/model_zoo.md +++ b/docs/model_zoo.md @@ -2,16 +2,57 @@ Scale hosts the following models in the LLM Engine Model Zoo: -| Model Name | Inference APIs Available | Fine-tuning APIs Available | -| --------------------- | ------------------------ | -------------------------- | -| `llama-7b` | ✅ | ✅ | -| `falcon-7b` | ✅ | | -| `falcon-7b-instruct` | ✅ | | -| `falcon-40b` | ✅ | | -| `falcon-40b-instruct` | ✅ | | -| `mpt-7b` | ✅ | | -| `mpt-7b-instruct` | ✅ | ✅ | -| `flan-t5-xxl` | ✅ | | +| Model Name | Inference APIs Available | Fine-tuning APIs Available | Inference Frameworks Available | Inference max total tokens (prompt + response) | +| --------------------------------- | ------------------------ | -------------------------- | ------------------------------------------ | ---------------------------------------------- | +| `llama-7b` | ✅ | ✅ | deepspeed, text-generation-inference | 2048 | +| `llama-2-7b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | +| `llama-2-7b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-13b` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-13b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-2-70b` | ✅ | ✅ | text-generation-inference, vllm | 4096 | +| `llama-2-70b-chat` | ✅ | | text-generation-inference, vllm | 4096 | +| `llama-3-8b` | ✅ | | vllm | 8192 | +| `llama-3-8b-instruct` | ✅ | | vllm | 8192 | +| `llama-3-70b` | ✅ | | vllm | 8192 | +| `llama-3-70b-instruct` | ✅ | | vllm | 8192 | +| `llama-3-1-8b` | ✅ | | vllm | 131072 | +| `llama-3-1-8b-instruct` | ✅ | | vllm | 131072 | +| `llama-3-1-70b` | ✅ | | vllm | 131072 | +| `llama-3-1-70b-instruct` | ✅ | | vllm | 131072 | +| `falcon-7b` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-7b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-40b` | ✅ | | text-generation-inference, vllm | 2048 | +| `falcon-40b-instruct` | ✅ | | text-generation-inference, vllm | 2048 | +| `mpt-7b` | ✅ | | deepspeed, text-generation-inference, vllm | 2048 | +| `mpt-7b-instruct` | ✅ | ✅ | deepspeed, text-generation-inference, vllm | 2048 | +| `flan-t5-xxl` | ✅ | | deepspeed, text-generation-inference | 2048 | +| `mistral-7b` | ✅ | ✅ | vllm | 8000 | +| `mistral-7b-instruct` | ✅ | ✅ | vllm | 8000 | +| `mixtral-8x7b` | ✅ | | vllm | 32768 | +| `mixtral-8x7b-instruct` | ✅ | | vllm | 32768 | +| `mixtral-8x22b` | ✅ | | vllm | 65536 | +| `mixtral-8x22b-instruct` | ✅ | | vllm | 65536 | +| `codellama-7b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-7b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-13b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-13b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-34b` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-34b-instruct` | ✅ | ✅ | text-generation-inference, vllm | 16384 | +| `codellama-70b` | ✅ | | vllm | 16384 | +| `codellama-70b-instruct` | ✅ | | vllm | 4096 | +| `zephyr-7b-alpha` | ✅ | | text-generation-inference, vllm | 32768 | +| `zephyr-7b-beta` | ✅ | | text-generation-inference, vllm | 32768 | +| `gemma-2b` | ✅ | | vllm | 8192 | +| `gemma-2b-instruct` | ✅ | | vllm | 8192 | +| `gemma-7b` | ✅ | | vllm | 8192 | +| `gemma-7b-instruct` | ✅ | | vllm | 8192 | +| `phi-3-mini-4k-instruct` | ✅ | | vllm | 4096 | +| `deepseek-coder-v2` | ✅ | | vllm | 131072 | +| `deepseek-coder-v2-instruct` | ✅ | | vllm | 131072 | +| `deepseek-coder-v2-lite` | ✅ | | vllm | 131072 | +| `deepseek-coder-v2-lite-instruct` | ✅ | | vllm | 131072 | +| `qwen2-72b-instruct` | ✅ | | vllm | 32768 | + ## Usage diff --git a/docs/pricing.md b/docs/pricing.md index 1929cc8c..e1ffa2fe 100644 --- a/docs/pricing.md +++ b/docs/pricing.md @@ -1,15 +1,10 @@ # Pricing -LLM Engine is being offered initially as a free preview. LLM Engine is an open-source project and free self-hosting will always be an option. - -## Hosted Models - -Once the limited preview period has ended, billing for hosted models will be managed through Scale's [Spellbook](https://spellbook.scale.com/settings) product. - -Scale Spellbook leverages usage-based spending, billed to a credit card. - -Scale will share usage-based pricing before completing the limited preview to all users. +LLM Engine is an open-source project and free [self-hosting](../guides/self_hosting) will always be an option. As of October 31st 2024, +the free demo service is sunsetted. ## Self-Hosted Models -We are committed to supporting the open-source community. Self-hosting LLM Engine will remain free and open-source. +We are committed to supporting the open-source community. [Self-hosting](../guides/self_hosting) LLM Engine will remain free and open-source. + +We would love [contributions](../contributing) from the community make this even more amazing! diff --git a/examples/download_a_finetuned_model.ipynb b/examples/download_a_finetuned_model.ipynb new file mode 100644 index 00000000..da548b12 --- /dev/null +++ b/examples/download_a_finetuned_model.ipynb @@ -0,0 +1,348 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8d3a4214", + "metadata": { + "id": "8d3a4214" + }, + "source": [ + "# Download a FineTuned Model \n", + "This notebook demonstrates how to download a finetuned model that you've created using LLM Engine and add it to huggingface!\n", + "\n", + "**This notebook is an extension of the previous finetuning notebook on ScienceQA**" + ] + }, + { + "cell_type": "markdown", + "id": "XK6VpTnOL4OV", + "metadata": { + "id": "XK6VpTnOL4OV" + }, + "source": [ + "# Packages Required\n", + "For this demo, we'll be using the `scale-llm-engine` package, the `datasets` package for downloading our finetuning dataset, `transformers`, and `huggingface_hub` for uploading our model to huggingface.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "S5u6DdInMEQ7", + "metadata": { + "id": "S5u6DdInMEQ7" + }, + "outputs": [], + "source": [ + "!pip install scale-llm-engine\n", + "!pip install transformers\n", + "!pip install datasets" + ] + }, + { + "cell_type": "markdown", + "id": "a3dc2a56", + "metadata": { + "id": "a3dc2a56" + }, + "source": [ + "# Data Preparation\n", + "Let's load in the dataset using Huggingface and view the features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e06ac39e", + "metadata": { + "id": "e06ac39e" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from smart_open import smart_open\n", + "import pandas as pd\n", + "\n", + "dataset = load_dataset('derek-thomas/ScienceQA')\n", + "dataset['train'].features" + ] + }, + { + "cell_type": "markdown", + "id": "1cbe8a58", + "metadata": { + "id": "1cbe8a58" + }, + "source": [ + "Now, let's format the dataset into what's acceptable for LLM Engine - a CSV file with 'prompt' and 'response' columns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b0eb8ad", + "metadata": { + "id": "0b0eb8ad" + }, + "outputs": [], + "source": [ + "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n", + "def format_options(options, choice_prefixes):\n", + " return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])\n", + "\n", + "def format_prompt(r, choice_prefixes):\n", + " options = format_options(r['choices'], choice_prefixes)\n", + " return f'''Context: {r[\"hint\"]}\\nQuestion: {r[\"question\"]}\\nOptions:{options}\\nAnswer:'''\n", + "\n", + "def format_label(r, choice_prefixes):\n", + " return choice_prefixes[r['answer']]\n", + "\n", + "def convert_dataset(ds):\n", + " prompts = [format_prompt(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " labels = [format_label(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " df = pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n", + " return df\n", + "\n", + "save_to_s3 = False\n", + "df_train = convert_dataset(dataset['train'])\n", + "if save_to_s3:\n", + " train_url = 's3://...'\n", + " val_url = 's3://...'\n", + " df_train = convert_dataset(dataset['train'])\n", + " with smart_open(train_url, 'wb') as f:\n", + " df_train.to_csv(f)\n", + "\n", + " df_val = convert_dataset(dataset['validation'])\n", + " with smart_open(val_url, 'wb') as f:\n", + " df_val.to_csv(f)\n", + "else:\n", + " # Gists of the already processed datasets\n", + " train_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_train.csv'\n", + " val_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_val.csv'\n", + "\n", + "df_train" + ] + }, + { + "cell_type": "markdown", + "id": "e2fc8d76", + "metadata": { + "id": "e2fc8d76" + }, + "source": [ + "# Fine-tune\n", + "Now, we can fine-tune the model using LLM Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4905d447", + "metadata": { + "id": "4905d447" + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ['SCALE_API_KEY'] = 'xxx'\n", + "\n", + "from llmengine import FineTune\n", + "\n", + "response = FineTune.create(\n", + " model=\"llama-2-7b\",\n", + " training_file=train_url,\n", + " validation_file=val_url,\n", + " hyperparameters={\n", + " 'lr':2e-4,\n", + " },\n", + " suffix='science-qa-llama'\n", + ")\n", + "run_id = response.id" + ] + }, + { + "cell_type": "markdown", + "id": "55074457", + "metadata": { + "id": "55074457" + }, + "source": [ + "We can sleep until the job completes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "840938dd", + "metadata": { + "id": "840938dd" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "while True:\n", + " job_status = FineTune.get(run_id).status\n", + " print(job_status)\n", + " if job_status == 'SUCCESS':\n", + " break\n", + " time.sleep(60)\n", + "\n", + "fine_tuned_model = FineTune.get(run_id).fine_tuned_model" + ] + }, + { + "cell_type": "markdown", + "id": "31278c6d", + "metadata": { + "id": "31278c6d" + }, + "source": [ + "# Downloading our Finetuned model \n", + "Let's download the weights for the new fine-tuned model using LLM Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f2f3f43", + "metadata": { + "id": "9f2f3f43" + }, + "outputs": [], + "source": [ + "from llmengine import Model\n", + "\n", + "response = Model.download(FineTune.get(run_id).fine_tune_model, download_format=\"hugging_face\")\n", + "print(response.urls)" + ] + }, + { + "cell_type": "markdown", + "id": "ae9cbdf3", + "metadata": {}, + "source": [ + "We now have a dictionary of filenames and urls that point to the file(s) where our finetuned model lives. We can download the associated finetuned model either synchronously or asynchronously." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "dc363e48", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import requests\n", + "\n", + "def download_files(url_dict, directory):\n", + " \"\"\"\n", + " Download files from given URLs to specified directory.\n", + " \n", + " Parameters:\n", + " - url_dict: Dictionary of {file_name: url} pairs.\n", + " - directory: Directory to save the files.\n", + " \"\"\"\n", + " if not os.path.exists(directory):\n", + " os.makedirs(directory)\n", + " \n", + " for file_name, url in url_dict.items():\n", + " response = requests.get(url, stream=True)\n", + " response.raise_for_status() # Raise an exception for HTTP errors\n", + " file_path = os.path.join(directory, file_name)\n", + " \n", + " with open(file_path, 'wb') as file:\n", + " for chunk in response.iter_content(chunk_size=8192):\n", + " file.write(chunk)\n", + "\n", + " " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "000e1633", + "metadata": {}, + "outputs": [], + "source": [ + "output_directory = \"YOUR_MODEL_DIR\"\n", + "download_files(response.urls, output_directory) " + ] + }, + { + "cell_type": "markdown", + "id": "e4e87233", + "metadata": {}, + "source": [ + "Lastly, we can upload our downloaded model to the huggingface hub." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7c8ee18", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install huggingface-hub" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "328efd19", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "from huggingface_hub import Repository\n", + "\n", + "HF_USERNAME = \"YOUR_HUGGINGFACE_USERNAME\"\n", + "HF_TOKEN = \"YOUR_HUGGINGFACE_TOKEN\"\n", + "\n", + "def upload_to_huggingface(directory, model_name):\n", + " \"\"\"\n", + " Upload files from a directory to the Hugging Face Hub as a new model.\n", + "\n", + " Parameters:\n", + " - directory: Directory containing the files to be uploaded.\n", + " - model_name: Name of the new model.\n", + " - token: Your Hugging Face authentication token.\n", + " \"\"\"\n", + " \n", + " # Create a repository with the given name\n", + " repo = Repository(directory, clone_from=f\"{HF_USERNAME}/{model_name}\", use_auth_token=HF_TOKEN)\n", + " \n", + " # Commit and push files\n", + " repo.push_to_hub()\n", + "\n", + "model_name = \"my-new-model\"\n", + " \n", + "upload_to_huggingface(output_directory, model_name, HF_TOKEN)" + ] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Environment (conda_pytorch_p38)", + "language": "python", + "name": "conda_pytorch_p38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/finetune_llama_2_on_science_qa.ipynb b/examples/finetune_llama_2_on_science_qa.ipynb new file mode 100644 index 00000000..dad7fe5e --- /dev/null +++ b/examples/finetune_llama_2_on_science_qa.ipynb @@ -0,0 +1,269 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "8d3a4214", + "metadata": { + "id": "8d3a4214" + }, + "source": [ + "# Finetune on ScienceQA\n", + "Let's use LLM Engine to fine-tune Llama-2 on ScienceQA!" + ] + }, + { + "cell_type": "markdown", + "id": "XK6VpTnOL4OV", + "metadata": { + "id": "XK6VpTnOL4OV" + }, + "source": [ + "# Packages Required\n", + "For this demo, we'll be using the `scale-llm-engine` package and `datasets` from Huggingface.\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "S5u6DdInMEQ7", + "metadata": { + "id": "S5u6DdInMEQ7" + }, + "outputs": [], + "source": [ + "!pip install scale-llm-engine\n", + "!pip install datasets" + ] + }, + { + "cell_type": "markdown", + "id": "a3dc2a56", + "metadata": { + "id": "a3dc2a56" + }, + "source": [ + "# Data Preparation\n", + "Let's load in the dataset using Huggingface and view the features." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e06ac39e", + "metadata": { + "id": "e06ac39e" + }, + "outputs": [], + "source": [ + "from datasets import load_dataset\n", + "from smart_open import smart_open\n", + "\n", + "dataset = load_dataset('derek-thomas/ScienceQA')\n", + "dataset['train'].features" + ] + }, + { + "cell_type": "markdown", + "id": "1cbe8a58", + "metadata": { + "id": "1cbe8a58" + }, + "source": [ + "Now, let's format the dataset into what's acceptable for LLM Engine - a CSV file with 'prompt' and 'response' columns." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0b0eb8ad", + "metadata": { + "id": "0b0eb8ad" + }, + "outputs": [], + "source": [ + "choice_prefixes = [chr(ord('A') + i) for i in range(26)] # A-Z\n", + "def format_options(options, choice_prefixes):\n", + " return ' '.join([f'({c}) {o}' for c, o in zip(choice_prefixes, options)])\n", + "\n", + "def format_prompt(r, choice_prefixes):\n", + " options = format_options(r['choices'], choice_prefixes)\n", + " return f'''Context: {r[\"hint\"]}\\nQuestion: {r[\"question\"]}\\nOptions:{options}\\nAnswer:'''\n", + "\n", + "def format_label(r, choice_prefixes):\n", + " return choice_prefixes[r['answer']]\n", + "\n", + "def convert_dataset(ds):\n", + " prompts = [format_prompt(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " labels = [format_label(i, choice_prefixes) for i in ds if i['hint'] != '']\n", + " df = pd.DataFrame.from_dict({'prompt': prompts, 'response': labels})\n", + " return df\n", + "\n", + "save_to_s3 = False\n", + "df_train = convert_dataset(dataset['train'])\n", + "if save_to_s3:\n", + " train_url = 's3://...'\n", + " val_url = 's3://...'\n", + " df_train = convert_dataset(dataset['train'])\n", + " with smart_open(train_url, 'wb') as f:\n", + " df_train.to_csv(f)\n", + "\n", + " df_val = convert_dataset(dataset['validation'])\n", + " with smart_open(val_url, 'wb') as f:\n", + " df_val.to_csv(f)\n", + "else:\n", + " # Gists of the already processed datasets\n", + " train_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_train.csv'\n", + " val_url = 'https://gist.githubusercontent.com/jihan-yin/43f19a86d35bf22fa3551d2806e478ec/raw/91416c09f09d3fca974f81d1f766dd4cadb29789/scienceqa_val.csv'\n", + "\n", + "df_train" + ] + }, + { + "cell_type": "markdown", + "id": "e2fc8d76", + "metadata": { + "id": "e2fc8d76" + }, + "source": [ + "# Fine-tune\n", + "Now, we can fine-tune the model using LLM Engine." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4905d447", + "metadata": { + "id": "4905d447" + }, + "outputs": [], + "source": [ + "import os\n", + "os.environ['SCALE_API_KEY'] = 'xxx'\n", + "\n", + "from llmengine import FineTune\n", + "\n", + "response = FineTune.create(\n", + " model=\"llama-2-7b\",\n", + " training_file=train_url,\n", + " validation_file=val_url,\n", + " hyperparameters={\n", + " 'lr':2e-4,\n", + " },\n", + " suffix='science-qa-llama'\n", + ")\n", + "run_id = response.id" + ] + }, + { + "cell_type": "markdown", + "id": "55074457", + "metadata": { + "id": "55074457" + }, + "source": [ + "We can sleep until the job completes." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "840938dd", + "metadata": { + "id": "840938dd" + }, + "outputs": [], + "source": [ + "import time\n", + "\n", + "while True:\n", + " job_status = FineTune.get(run_id).status\n", + " print(job_status)\n", + " if job_status == 'SUCCESS':\n", + " break\n", + " time.sleep(60)\n", + "\n", + "fine_tuned_model = FineTune.get(run_id).fine_tuned_model" + ] + }, + { + "cell_type": "markdown", + "id": "31278c6d", + "metadata": { + "id": "31278c6d" + }, + "source": [ + "# Inference and Evaluation\n", + "Let's evaluate the new fine-tuned model by running inference against it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3b9d7643", + "metadata": { + "id": "3b9d7643" + }, + "outputs": [], + "source": [ + "import pandas as pd\n", + "from llmengine import Completion\n", + "\n", + "# Helper function to get outputs for fine-tuned model with retries\n", + "def get_output(prompt: str, num_retry: int = 5):\n", + " for _ in range(num_retry):\n", + " try:\n", + " response = Completion.create(\n", + " model=fine_tuned_model,\n", + " prompt=prompt,\n", + " max_new_tokens=1,\n", + " temperature=0.01\n", + " )\n", + " return response.output.text.strip()\n", + " except Exception as e:\n", + " print(e)\n", + " return \"\"\n", + "\n", + "# Read the test data\n", + "test = pd.read_csv(val_url)\n", + "\n", + "test[\"prediction\"] = test[\"prompt\"].apply(get_output)\n", + "print(f\"Accuracy: {(test['response'] == test['prediction']).mean() * 100:.2f}%\")" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9f2f3f43", + "metadata": { + "id": "9f2f3f43" + }, + "outputs": [], + "source": [] + } + ], + "metadata": { + "colab": { + "provenance": [] + }, + "kernelspec": { + "display_name": "Environment (conda_pytorch_p38)", + "language": "python", + "name": "conda_pytorch_p38" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.12" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/server/llm_engine_server/common/dtos/__init__.py b/integration_tests/__init__.py similarity index 100% rename from server/llm_engine_server/common/dtos/__init__.py rename to integration_tests/__init__.py diff --git a/integration_tests/rest_api_utils.py b/integration_tests/rest_api_utils.py new file mode 100644 index 00000000..7db937dc --- /dev/null +++ b/integration_tests/rest_api_utils.py @@ -0,0 +1,1077 @@ +import asyncio +import inspect +import json +import os +import re +import time +from typing import Any, Dict, List, Optional, Sequence + +import aiohttp +import requests +from model_engine_server.common.dtos.tasks import TaskStatus +from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +_DEFAULT_BASE_PATH = "http://localhost:5001" +BASE_PATH = os.environ.get("BASE_PATH", _DEFAULT_BASE_PATH) +print(f"Integration tests using gateway {BASE_PATH=}") +DEFAULT_NETWORK_TIMEOUT_SEC = 10 +LONG_NETWORK_TIMEOUT_SEC = 30 + +# add suffix to avoid name collisions +SERVICE_IDENTIFIER = os.environ.get("SERVICE_IDENTIFIER", "") + + +def format_name(name: str) -> str: + return f"{name}-{SERVICE_IDENTIFIER}" if SERVICE_IDENTIFIER else name + + +# Use the scale-launch-integration-tests id +USER_ID_0 = os.getenv("TEST_USER_ID", "fakeuser") + +DEFAULT_USERS: Sequence[str] = (USER_ID_0,) # type: ignore + + +def echo_load_predict_fn(model): + def echo(**keyword_args): + return model(**keyword_args) + + return echo + + +def echo_load_model_fn(): + def my_model(**keyword_args): + return {k: v for k, v in keyword_args.items()} + + return my_model + + +CREATE_MODEL_BUNDLE_REQUEST_SIMPLE = { + "name": "model_bundle_simple", + "schema_location": "s3://model-engine-integration-tests/model_bundles/echo_schemas", + "metadata": { + "test_key": "test_value", + }, + "flavor": { + "flavor": "cloudpickle_artifact", + "load_predict_fn": inspect.getsource(echo_load_predict_fn), + "load_model_fn": inspect.getsource(echo_load_model_fn), + "framework": { + "framework_type": "pytorch", + "pytorch_image_tag": "1.11.0-cuda11.3-cudnn8-runtime", + }, + "requirements": [ + "cloudpickle==2.1.0", + "pyyaml==6.0", + "pydantic==2.8.2", + "fastapi==0.110.0", + ], + "location": "s3://model-engine-integration-tests/model_bundles/echo_bundle", + }, +} + +CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE = { + "name": "model_bundle_runnable_image", + "schema_location": "s3://model-engine-integration-tests/model_bundles/echo_schemas", + "metadata": { + "test_key": "test_value", + }, + "flavor": { + "flavor": "streaming_enhanced_runnable_image", + "repository": "model-engine", + "tag": "830c81ecba2a147022e504917c6ce18b00c2af44", + "command": [ + "dumb-init", + "--", + "ddtrace-run", + "python", + "-m", + "model_engine_server.inference.forwarding.echo_server", + "--port", + "5005", + ], + "streaming_command": [ + "dumb-init", + "--", + "ddtrace-run", + "python", + "-m", + "model_engine_server.inference.forwarding.echo_server", + "--port", + "5005", + ], + "env": { + "TEST_KEY": "test_value", + "ML_INFRA_SERVICES_CONFIG_PATH": "/workspace/model-engine/model_engine_server/core/configs/default.yaml", + # infra configs are mounted here + "HTTP_HOST": "0.0.0.0", # Hack for uvicorn to work in minikube + }, + "protocol": "http", + "readiness_initial_delay_seconds": 20, + }, +} + +CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE = { + "bundle_name": "model_bundle_simple", + "name": format_name("model-endpoint-simple-async"), + "endpoint_type": "async", + "cpus": "0.5", + "memory": "500Mi", + "storage": "1Gi", + "min_workers": 1, + "max_workers": 1, + "gpus": 0, + "per_worker": 1, + "labels": {"team": "infra", "product": "launch"}, + "metadata": {}, +} + +CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE = CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE.copy() +CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE["name"] = format_name("model-endpoint-simple-sync") +CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE["endpoint_type"] = "sync" + +CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE = { + "bundle_name": "model_bundle_runnable_image", + "name": format_name("model-endpoint-runnable-async"), + "post_inference_hooks": [], + "endpoint_type": "async", + "cpus": "1", + "gpus": 0, + "memory": "1Gi", + "storage": "2Gi", + "optimize_costs": False, + "min_workers": 1, + "max_workers": 1, + "per_worker": 1, + "labels": {"team": "infra", "product": "launch"}, + "metadata": {"key": "value"}, +} + +CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE = ( + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE.copy() +) +CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE["name"] = format_name( + "model-endpoint-runnable-sync-streaming" +) +CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE["endpoint_type"] = "streaming" + +UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE = { + "bundle_name": "model_bundle_simple", + "cpus": "1", + "memory": "1Gi", + "max_workers": 2, +} + +UPDATE_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE = { + "bundle_name": "model_bundle_runnable_image", + "cpus": "2", + "memory": "2Gi", + "max_workers": 2, +} + +INFERENCE_PAYLOAD: Dict[str, Any] = { + "args": {"y": 1}, + "url": None, +} + +CREATE_LLM_MODEL_ENDPOINT_REQUEST: Dict[str, Any] = { + "name": format_name("llama-2-7b-test"), + "model_name": "llama-2-7b-chat", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "endpoint_type": "streaming", + "cpus": 20, + "gpus": 1, + "memory": "20Gi", + "gpu_type": "nvidia-hopper-h100-1g20gb", + "storage": "40Gi", + "optimize_costs": False, + "min_workers": 1, + "max_workers": 1, + "per_worker": 1, + "labels": {"team": "infra", "product": "launch"}, + "metadata": {"key": "value"}, + "public_inference": False, +} + + +INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE: Dict[str, Any] = INFERENCE_PAYLOAD.copy() +INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE["return_pickled"] = False + +INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE: Dict[str, Any] = INFERENCE_PAYLOAD.copy() +INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE["return_pickled"] = True + +LLM_PAYLOAD: Dict[str, Any] = { + "prompt": "Hello, my name is", + "max_new_tokens": 10, + "temperature": 0.2, +} + +LLM_PAYLOAD_WITH_STOP_SEQUENCE: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_STOP_SEQUENCE["stop_sequences"] = ["\n"] + +LLM_PAYLOAD_WITH_PRESENCE_PENALTY: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_PRESENCE_PENALTY["presence_penalty"] = 0.5 + +LLM_PAYLOAD_WITH_FREQUENCY_PENALTY: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_FREQUENCY_PENALTY["frequency_penalty"] = 0.5 + +LLM_PAYLOAD_WITH_TOP_K: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_TOP_K["top_k"] = 10 + +LLM_PAYLOAD_WITH_TOP_P: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_TOP_P["top_p"] = 0.5 + +LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT["include_stop_str_in_output"] = True + +LLM_PAYLOAD_WITH_GUIDED_JSON: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_GUIDED_JSON["guided_json"] = { + "properties": {"myString": {"type": "string"}}, + "required": ["myString"], +} + +LLM_PAYLOAD_WITH_GUIDED_REGEX: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_GUIDED_REGEX["guided_regex"] = "Sean.*" + +LLM_PAYLOAD_WITH_GUIDED_CHOICE: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_GUIDED_CHOICE["guided_choice"] = ["dog", "cat"] + +LLM_PAYLOAD_WITH_GUIDED_GRAMMAR: Dict[str, Any] = LLM_PAYLOAD.copy() +LLM_PAYLOAD_WITH_GUIDED_GRAMMAR["guided_grammar"] = 'start: "John"' + +LLM_PAYLOADS_WITH_EXPECTED_RESPONSES = [ + (LLM_PAYLOAD, None, None), + (LLM_PAYLOAD_WITH_STOP_SEQUENCE, None, None), + (LLM_PAYLOAD_WITH_PRESENCE_PENALTY, None, None), + (LLM_PAYLOAD_WITH_FREQUENCY_PENALTY, None, None), + (LLM_PAYLOAD_WITH_TOP_K, None, None), + (LLM_PAYLOAD_WITH_TOP_P, None, None), + (LLM_PAYLOAD_WITH_INCLUDE_STOP_STR_IN_OUTPUT, ["tokens"], None), + (LLM_PAYLOAD_WITH_GUIDED_JSON, None, None), + (LLM_PAYLOAD_WITH_GUIDED_REGEX, None, "Sean.*"), + (LLM_PAYLOAD_WITH_GUIDED_CHOICE, None, "dog|cat"), + (LLM_PAYLOAD_WITH_GUIDED_GRAMMAR, None, "John"), +] + +CREATE_BATCH_JOB_REQUEST: Dict[str, Any] = { + "bundle_name": "model_bundle_simple", + "input_path": "TBA", + "serialization_format": "JSON", + "labels": {"team": "infra", "product": "launch"}, + "resource_requests": { + "memory": "500Mi", + "max_workers": 1, + "gpus": 0, + }, +} + +CREATE_DOCKER_IMAGE_BATCH_JOB_BUNDLE_REQUEST: Dict[str, Any] = { + "name": format_name("di_batch_job_bundle_1"), + "image_repository": "model-engine", + "image_tag": "830c81ecba2a147022e504917c6ce18b00c2af44", + "command": ["jq", ".", "/launch_mount_location/file"], + "env": {"ENV1": "VAL1"}, + "mount_location": "/launch_mount_location/file", + "resource_requests": { + "cpus": 0.1, + "memory": "10Mi", + }, +} + +CREATE_DOCKER_IMAGE_BATCH_JOB_REQUEST: Dict[str, Any] = { + "docker_image_batch_job_bundle_name": format_name("di_batch_job_bundle_1"), + "job_config": {"data": {"to": "mount"}}, + "labels": {"team": "infra", "product": "testing"}, + "resource_requests": {"cpus": 0.15, "memory": "15Mi"}, +} + +CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST: Dict[str, Any] = { + "name": format_name("fine_tune_di_batch_job_bundle_1"), + "image_repository": "model-engine", + "image_tag": "830c81ecba2a147022e504917c6ce18b00c2af44", + "command": ["cat", "/launch_mount_location/file"], + "env": {"ENV1": "VAL1"}, + "mount_location": "/launch_mount_location/file", + "resource_requests": { + "cpus": 0.1, + "memory": "10Mi", + }, + "public": True, +} + +CREATE_FINE_TUNE_REQUEST: Dict[str, Any] = { + "model": "test_base_model", + "training_file": "s3://model-engine-integration-tests/fine_tune_files/run_through_walls.csv", + "validation_file": None, + # "fine_tuning_method": "test_fine_tuning_method", # ignored until we change it + "hyperparameters": {}, +} + + +@retry(stop=stop_after_attempt(300), wait=wait_fixed(2)) +def ensure_launch_gateway_healthy(): + assert requests.get(f"{BASE_PATH}/healthz").status_code == 200 + + +def create_model_bundle( + create_model_bundle_request: Dict[str, Any], user_id: str, version: str +) -> Dict[str, Any]: + response = requests.post( + f"{BASE_PATH}/{version}/model-bundles", + json=create_model_bundle_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) +def get_latest_model_bundle(model_name: str, user_id: str, version: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/{version}/model-bundles/latest?model_name={model_name}", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_or_create_model_bundle( + create_model_bundle_request: Dict[str, Any], user_id: str, version: str +) -> Dict[str, Any]: + # In v1, we will no longer have the uniqueness constraint of (name, created_by) but right now + # for backwards compatibility, such a constraint exists. As a result, we use this get-or-create + # method as a temporary workaround since v1 will not support bundle deletion initially. + try: + return get_latest_model_bundle(create_model_bundle_request["name"], user_id, version) + except: # noqa: E722 + return create_model_bundle(create_model_bundle_request, user_id, version) + + +def replace_model_bundle_name_with_id(request: Dict[str, Any], user_id: str, version): + if "bundle_name" in request: + model_bundle = get_latest_model_bundle(request["bundle_name"], user_id, version) + request["model_bundle_id"] = model_bundle["id"] + del request["bundle_name"] + + +def create_model_endpoint( + create_model_endpoint_request: Dict[str, Any], user_id: str +) -> Dict[str, Any]: + create_model_endpoint_request = create_model_endpoint_request.copy() + replace_model_bundle_name_with_id(create_model_endpoint_request, user_id, "v1") + response = requests.post( + f"{BASE_PATH}/v1/model-endpoints", + json=create_model_endpoint_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def create_batch_job(create_batch_job_request: Dict[str, Any], user_id: str) -> Dict[str, Any]: + create_batch_job_request = create_batch_job_request.copy() + replace_model_bundle_name_with_id(create_batch_job_request, user_id, "v2") + response = requests.post( + f"{BASE_PATH}/v1/batch-jobs", + json=create_batch_job_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def cancel_batch_job(batch_job_id: str, user_id: str) -> Dict[str, Any]: + response = requests.put( + f"{BASE_PATH}/v1/batch-jobs/{batch_job_id}", + json={"cancel": True}, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def create_docker_image_batch_job_bundle( + create_docker_image_batch_job_bundle_request: Dict[str, Any], user_id: str +) -> Dict[str, Any]: + response = requests.post( + f"{BASE_PATH}/v1/docker-image-batch-job-bundles", + json=create_docker_image_batch_job_bundle_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_latest_docker_image_batch_job_bundle(bundle_name: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/docker-image-batch-job-bundles/latest?bundle_name={bundle_name}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_or_create_docker_image_batch_job_bundle( + create_docker_image_batch_job_bundle_request: Dict[str, Any], user_id: str +): + try: + return get_latest_docker_image_batch_job_bundle( + create_docker_image_batch_job_bundle_request["name"], user_id + ) + except: # noqa: E722 + return create_docker_image_batch_job_bundle( + create_docker_image_batch_job_bundle_request, user_id + ) + + +def get_docker_image_batch_job_bundle_by_id( + docker_image_batch_job_bundle_id: str, user_id: str +) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/docker-image-batch-job-bundles/{docker_image_batch_job_bundle_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def create_docker_image_batch_job( + create_docker_image_batch_job_request: Dict[str, Any], user_id: str +) -> Dict[str, Any]: + response = requests.post( + f"{BASE_PATH}/v1/docker-image-batch-jobs", + json=create_docker_image_batch_job_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_docker_image_batch_job(batch_job_id: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/docker-image-batch-jobs/{batch_job_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def create_fine_tune(create_fine_tune_request: Dict[str, Any], user_id: str) -> Dict[str, Any]: + response = requests.post( + f"{BASE_PATH}/v1/llm/fine-tunes", + json=create_fine_tune_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_fine_tune_by_id(fine_tune_id: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/llm/fine-tunes/{fine_tune_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def list_fine_tunes(user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/llm/fine-tunes", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def cancel_fine_tune_by_id(fine_tune_id: str, user_id: str) -> Dict[str, Any]: + response = requests.put( + f"{BASE_PATH}/v1/llm/fine-tunes/{fine_tune_id}/cancel", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def upload_file(file, user_id: str) -> Dict[str, Any]: + files = {"file": file} + response = requests.post( + f"{BASE_PATH}/v1/files", + files=files, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_file_by_id(file_id: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/files/{file_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def list_files(user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/files", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=30, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def delete_file_by_id(file_id: str, user_id: str) -> Dict[str, Any]: + response = requests.delete( + f"{BASE_PATH}/v1/files/{file_id}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def get_file_content_by_id(file_id: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/files/{file_id}/content", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +@retry(stop=stop_after_attempt(6), wait=wait_fixed(1)) +def get_model_endpoint(name: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/model-endpoints?name={name}", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json()["model_endpoints"][0] + + +@retry(stop=stop_after_attempt(6), wait=wait_fixed(1)) +def get_llm_model_endpoint(name: str, user_id: str) -> Dict[str, Any]: + response = requests.get( + f"{BASE_PATH}/v1/llm/model-endpoints/{name}", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(20)) +def update_model_endpoint( + endpoint_name: str, update_model_endpoint_request: Dict[str, Any], user_id: str +) -> Dict[str, Any]: + update_model_endpoint_request = update_model_endpoint_request.copy() + replace_model_bundle_name_with_id(update_model_endpoint_request, user_id, "v2") + endpoint = get_model_endpoint(endpoint_name, user_id) + response = requests.put( + f"{BASE_PATH}/v1/model-endpoints/{endpoint['id']}", + json=update_model_endpoint_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def delete_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]: + endpoint = get_model_endpoint(endpoint_name, user_id) + response = requests.delete( + f"{BASE_PATH}/v1/model-endpoints/{endpoint['id']}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +def delete_llm_model_endpoint(endpoint_name: str, user_id: str) -> Dict[str, Any]: + response = requests.delete( + f"{BASE_PATH}/v1/llm/model-endpoints/{endpoint_name}", + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) +def list_model_endpoints(user_id: str) -> List[Dict[str, Any]]: + response = requests.get( + f"{BASE_PATH}/v1/model-endpoints", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json()["model_endpoints"] + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) +def list_llm_model_endpoints(user_id: str) -> List[Dict[str, Any]]: + response = requests.get( + f"{BASE_PATH}/v1/llm/model-endpoints", + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json()["model_endpoints"] + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(1)) +def create_llm_model_endpoint( + create_llm_model_endpoint_request: Dict[str, Any], + user_id: str, + inference_framework: Optional[str], + inference_framework_image_tag: Optional[str], +) -> Dict[str, Any]: + create_model_endpoint_request = create_llm_model_endpoint_request.copy() + if inference_framework: + create_model_endpoint_request["inference_framework"] = inference_framework + if inference_framework_image_tag: + create_model_endpoint_request["inference_framework_image_tag"] = ( + inference_framework_image_tag + ) + response = requests.post( + f"{BASE_PATH}/v1/llm/model-endpoints", + json=create_model_endpoint_request, + headers={"Content-Type": "application/json"}, + auth=(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) + if not response.ok: + raise ValueError(response.content) + return response.json() + + +async def create_async_task( + model_endpoint_id: str, + create_async_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/async-tasks?model_endpoint_id={model_endpoint_id}", + json=create_async_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) as response: + return (await response.json())["task_id"] + + +async def create_async_tasks( + endpoint_name: str, create_async_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + endpoint = get_model_endpoint(endpoint_name, user_id) + async with aiohttp.ClientSession() as session: + tasks = [] + for create_async_task_request in create_async_task_requests: + task = create_async_task(endpoint["id"], create_async_task_request, user_id, session) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +async def create_sync_task( + model_endpoint_id: str, + create_sync_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/sync-tasks?model_endpoint_id={model_endpoint_id}", + json=create_sync_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) as response: + assert response.status == 200, (await response.read()).decode() + return await response.json() + + +async def create_llm_sync_task( + model_endpoint_name: str, + create_sync_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/llm/completions-sync?model_endpoint_name={model_endpoint_name}", + json=create_sync_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=LONG_NETWORK_TIMEOUT_SEC, + ) as response: + assert response.status == 200, (await response.read()).decode() + return await response.json() + + +async def create_streaming_task( + model_endpoint_id: str, + create_streaming_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/streaming-tasks?model_endpoint_id={model_endpoint_id}", + json=create_streaming_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) as response: + assert response.status == 200, (await response.read()).decode() + return (await response.read()).decode() + + +async def create_llm_streaming_task( + model_endpoint_name: str, + create_streaming_task_request: Dict[str, Any], + user_id: str, + session: aiohttp.ClientSession, +) -> str: + async with session.post( + f"{BASE_PATH}/v1/llm/completions-stream?model_endpoint_name={model_endpoint_name}", + json=create_streaming_task_request, + headers={"Content-Type": "application/json"}, + auth=aiohttp.BasicAuth(user_id, ""), + timeout=LONG_NETWORK_TIMEOUT_SEC, + ) as response: + assert response.status == 200, (await response.read()).decode() + return (await response.read()).decode() + + +async def create_sync_tasks( + endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + endpoint = get_model_endpoint(endpoint_name, user_id) + async with aiohttp.ClientSession() as session: + tasks = [] + for create_sync_task_request in create_sync_task_requests: + task = create_sync_task(endpoint["id"], create_sync_task_request, user_id, session) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +async def create_llm_sync_tasks( + endpoint_name: str, create_sync_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + async with aiohttp.ClientSession() as session: + tasks = [] + for create_sync_task_request in create_sync_task_requests: + task = create_llm_sync_task(endpoint_name, create_sync_task_request, user_id, session) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +async def create_streaming_tasks( + endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + endpoint = get_model_endpoint(endpoint_name, user_id) + async with aiohttp.ClientSession() as session: + tasks = [] + for create_streaming_task_request in create_streaming_task_requests: + task = create_streaming_task( + endpoint["id"], create_streaming_task_request, user_id, session + ) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +async def create_llm_streaming_tasks( + endpoint_name: str, create_streaming_task_requests: List[Dict[str, Any]], user_id: str +) -> List[Any]: + async with aiohttp.ClientSession() as session: + tasks = [] + for create_streaming_task_request in create_streaming_task_requests: + task = create_llm_streaming_task( + endpoint_name, create_streaming_task_request, user_id, session + ) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +async def get_async_task( + task_id: str, user_id: str, session: aiohttp.ClientSession +) -> Dict[str, Any]: + async with session.get( + f"{BASE_PATH}/v1/async-tasks/{task_id}", + auth=aiohttp.BasicAuth(user_id, ""), + timeout=DEFAULT_NETWORK_TIMEOUT_SEC, + ) as response: + return await response.json() + + +async def get_async_tasks(task_ids: List[str], user_id: str) -> List[Dict[str, Any]]: + async with aiohttp.ClientSession() as session: + tasks = [] + for task_id in task_ids: + task = get_async_task(task_id, user_id, session) + tasks.append(asyncio.create_task(task)) + + result = await asyncio.gather(*tasks) + return result # type: ignore + + +# Wait 25 minutes (1500 seconds) for endpoints to build. +@retry(stop=stop_after_attempt(25), wait=wait_fixed(60)) +def ensure_n_ready_endpoints_long(n: int, user_id: str): + endpoints = list_model_endpoints(user_id) + ready_endpoints = [endpoint for endpoint in endpoints if endpoint["status"] == "READY"] + print( + f"User {user_id} Current num endpoints: {len(endpoints)}, num ready endpoints: {len(ready_endpoints)}" + ) + assert ( + len(ready_endpoints) >= n + ), f"Expected {n} ready endpoints, got {len(ready_endpoints)}. Look through endpoint builder for errors." + + +# Wait 2 minutes (120 seconds) for endpoints to build. +@retry(stop=stop_after_attempt(12), wait=wait_fixed(10)) +def ensure_n_ready_endpoints_short(n: int, user_id: str): + endpoints = list_model_endpoints(user_id) + ready_endpoints = [endpoint for endpoint in endpoints if endpoint["status"] == "READY"] + print( + f"User {user_id} Current num endpoints: {len(endpoints)}, num ready endpoints: {len(ready_endpoints)}" + ) + assert len(ready_endpoints) >= n + + +# Wait 2 minutes (120 seconds) for endpoints to build. +@retry(stop=stop_after_attempt(12), wait=wait_fixed(10)) +def ensure_n_ready_private_llm_endpoints_short(n: int, user_id: str): + endpoints = list_llm_model_endpoints(user_id) + private_endpoints = [ + endpoint for endpoint in endpoints if not endpoint["spec"]["public_inference"] + ] + ready_endpoints = [endpoint for endpoint in private_endpoints if endpoint["status"] == "READY"] + print( + f"User {user_id} Current num endpoints: {len(private_endpoints)}, num ready endpoints: {len(ready_endpoints)}" + ) + assert ( + len(ready_endpoints) >= n + ), f"Expected {n} ready endpoints, got {len(ready_endpoints)}. Look through endpoint builder for errors." + + +def delete_all_endpoints(user_id: str, delete_suffix_only: bool): + endpoints = list_model_endpoints(user_id) + for i, endpoint in enumerate(endpoints): + if ( + delete_suffix_only + and SERVICE_IDENTIFIER + and not endpoint["name"].endswith(SERVICE_IDENTIFIER) + ): + continue + + response = delete_model_endpoint(endpoint["name"], user_id) + assert response["deleted"] + print(f"[{i + 1}/{len(endpoints)}] Deleted {endpoint=}") + + +# Wait up to 5 minutes (300 seconds) for the gateway to be ready. +@retry(stop=stop_after_attempt(30), wait=wait_fixed(10)) +def ensure_gateway_ready(): + response = requests.get(f"{BASE_PATH}/healthz") + assert response.ok + + +# Wait up to 10 minutes (600 seconds) for the pods to spin up. +@retry(stop=stop_after_attempt(200), wait=wait_fixed(3)) +def ensure_nonzero_available_workers(endpoint_name: str, user_id: str): + simple_endpoint = get_model_endpoint(endpoint_name, user_id) + assert simple_endpoint.get("deployment_state", {}).get("available_workers", 0) + + +# Wait up to 20 minutes (1200 seconds) for the pods to spin up. +@retry(stop=stop_after_attempt(120), wait=wait_fixed(10)) +def ensure_nonzero_available_llm_workers(endpoint_name: str, user_id: str): + simple_endpoint = get_llm_model_endpoint(endpoint_name, user_id) + assert simple_endpoint["spec"].get("deployment_state", {}).get("available_workers", 0) + + +def ensure_inference_task_response_is_correct(response: Dict[str, Any], return_pickled: bool): + print(response) + assert response["status"] == "SUCCESS" + assert response["traceback"] is None + if return_pickled: + assert response["result"]["result_url"].startswith("s3://") + else: + assert response["result"] == {"result": '{"y": 1}'} + + +def ensure_llm_task_response_is_correct( + response: Dict[str, Any], + required_output_fields: Optional[List[str]], + response_text_regex: Optional[str], +): + print(response) + assert response["output"] is not None + + if required_output_fields is not None: + for field in required_output_fields: + assert field in response["output"] + + if response_text_regex is not None: + assert re.search(response_text_regex, response["output"]["text"]) + + +def ensure_llm_task_stream_response_is_correct( + response: str, + required_output_fields: Optional[List[str]], + response_text_regex: Optional[str], +): + # parse response + # data has format "data: \n\ndata: \n\n" + # We want to get a list of dictionaries parsing out the 'data:' field + parsed_response = [ + json.loads(r.split("data: ")[1]) for r in response.split("\n") if "data:" in r.strip() + ] + + # Join the text field of the response + response_text = "".join([r["output"]["text"] for r in parsed_response]) + print("response text: ", response_text) + assert response_text is not None + + if response_text_regex is not None: + assert re.search(response_text_regex, response_text) + + +# Wait up to 30 seconds for the tasks to be returned. +@retry( + stop=stop_after_attempt(10), wait=wait_fixed(1), retry=retry_if_exception_type(AssertionError) +) +def ensure_all_async_tasks_success(task_ids: List[str], user_id: str, return_pickled: bool): + responses = asyncio.run(get_async_tasks(task_ids, user_id)) + for response in responses: + if response["status"] not in (TaskStatus.PENDING, TaskStatus.SUCCESS, TaskStatus.STARTED): + print(response) + raise ValueError("Task failed!") + ensure_inference_task_response_is_correct(response, return_pickled) + + +def delete_existing_endpoints( + users: Sequence[str] = DEFAULT_USERS, delete_suffix_only: bool = True +) -> None: + if len(users) == 0: + raise ValueError("Must supply at least one user!") + + # list all endpoints before attempting to delete them + print(f"[{len({users})} ] Listing all user endpoints... ({users})") + all_endpoint_info = [] + for i, u in enumerate(users): + u_endpoints = list_model_endpoints(u) + all_endpoint_info.append(u_endpoints) + k8s_endpoint_names = [ + f"launch-endpoint-id-{endpoint['id'].replace('_', '-')}" for endpoint in u_endpoints + ] + print( + f"[{i + 1}/{len(users)}] {len(u_endpoints)} endpoints for user {u}: {k8s_endpoint_names}" + ) + + if all([len(info) == 0 for info in all_endpoint_info]): + return + + # delete the endpoints: if this fails, manually remove the dangling k8s deployments + # and delete the user's endpoints from the hosted_model_inference.endpoints table + # i.e. by default this is running the following SQL: + # + # >>>> delete from model_engine_server.endpoints where created_by in ( + # >>>> 'test00000000000000000000', + # >>>> 'test11111111111111111111', + # >>>> ) + # + time.sleep(15) # need to sleep to allow the cache to refresh + print(f"[{len({users})}] Deleting all user endpoints...") + try: + for i, u in enumerate(users): + suffix_msg = f" with suffix {SERVICE_IDENTIFIER}" if delete_suffix_only else "" + print(f"[{i + 1}/{len(users)}] Deleting all endpoints{suffix_msg} for user with ID {u}") + delete_all_endpoints(u, delete_suffix_only) + except Exception: # noqa + try: + j: str = json.dumps(all_endpoint_info, indent=2) + except Exception as j_error: # noqa + j = f"[FAILED TO JSON ENCODE {j_error}]\n{all_endpoint_info}" + barrier: str = "-" * 80 + print(f"ERROR! Deletion failed. All endpoint information:\n{barrier}\n{j}\n{barrier}") + raise + time.sleep(15) diff --git a/integration_tests/test_batch_jobs.py b/integration_tests/test_batch_jobs.py new file mode 100644 index 00000000..8f4f1dec --- /dev/null +++ b/integration_tests/test_batch_jobs.py @@ -0,0 +1,25 @@ +from .rest_api_utils import ( + CREATE_BATCH_JOB_REQUEST, + CREATE_DOCKER_IMAGE_BATCH_JOB_BUNDLE_REQUEST, + CREATE_DOCKER_IMAGE_BATCH_JOB_REQUEST, + USER_ID_0, + cancel_batch_job, + create_batch_job, + create_docker_image_batch_job, + get_or_create_docker_image_batch_job_bundle, +) +from .test_bundles import model_bundles # noqa + + +def test_di_batch_jobs(model_bundles) -> None: # noqa + get_or_create_docker_image_batch_job_bundle( + CREATE_DOCKER_IMAGE_BATCH_JOB_BUNDLE_REQUEST, USER_ID_0 + ) + create_docker_image_batch_job(CREATE_DOCKER_IMAGE_BATCH_JOB_REQUEST, USER_ID_0) + + batch_job_id = create_batch_job(CREATE_BATCH_JOB_REQUEST, USER_ID_0)["job_id"] + + # TODO: assert that batch job actually succeeds. + + cancel_response = cancel_batch_job(batch_job_id, USER_ID_0) + assert cancel_response["success"] diff --git a/integration_tests/test_bundles.py b/integration_tests/test_bundles.py new file mode 100644 index 00000000..3e8c47d8 --- /dev/null +++ b/integration_tests/test_bundles.py @@ -0,0 +1,26 @@ +import pytest +from tenacity import retry, stop_after_attempt, wait_fixed + +from .rest_api_utils import ( + CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, + CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, + USER_ID_0, + create_model_bundle, + ensure_launch_gateway_healthy, + get_latest_model_bundle, +) + + +@pytest.fixture(scope="session") +@retry(stop=stop_after_attempt(10), wait=wait_fixed(30)) +def model_bundles(): + ensure_launch_gateway_healthy() + user = USER_ID_0 + for create_bundle_request in [ + CREATE_MODEL_BUNDLE_REQUEST_SIMPLE, + CREATE_MODEL_BUNDLE_REQUEST_RUNNABLE_IMAGE, + ]: + create_model_bundle(create_bundle_request, user, "v2") + bundle = get_latest_model_bundle(create_bundle_request["name"], user, "v2") + assert bundle["name"] == create_bundle_request["name"] + assert bundle["metadata"] == create_bundle_request["metadata"] diff --git a/integration_tests/test_completions.py b/integration_tests/test_completions.py new file mode 100644 index 00000000..aac6b213 --- /dev/null +++ b/integration_tests/test_completions.py @@ -0,0 +1,99 @@ +import asyncio +import os + +import pytest + +from .rest_api_utils import ( + CREATE_LLM_MODEL_ENDPOINT_REQUEST, + LLM_PAYLOADS_WITH_EXPECTED_RESPONSES, + USER_ID_0, + create_llm_model_endpoint, + create_llm_streaming_tasks, + create_llm_sync_tasks, + delete_llm_model_endpoint, + ensure_launch_gateway_healthy, + ensure_llm_task_response_is_correct, + ensure_llm_task_stream_response_is_correct, + ensure_n_ready_private_llm_endpoints_short, + ensure_nonzero_available_llm_workers, +) + +TEST_INFERENCE_FRAMEWORK = os.environ.get("TEST_INFERENCE_FRAMEWORK", None) +TEST_INFERENCE_FRAMEWORK_IMAGE_TAG = os.environ.get("TEST_INFERENCE_FRAMEWORK_IMAGE_TAG", None) +print(f"TEST_INFERENCE_FRAMEWORK={TEST_INFERENCE_FRAMEWORK}") + + +@pytest.mark.skipif( + (not TEST_INFERENCE_FRAMEWORK) or (not TEST_INFERENCE_FRAMEWORK_IMAGE_TAG), + reason="Skip unless running inference framework tests", +) +def test_completions(capsys): + ensure_launch_gateway_healthy() + with capsys.disabled(): + try: + user = USER_ID_0 + create_endpoint_request = CREATE_LLM_MODEL_ENDPOINT_REQUEST + + print(f"Creating {create_endpoint_request['name']} model endpoint...") + create_llm_model_endpoint( + create_endpoint_request, + user, + TEST_INFERENCE_FRAMEWORK, + TEST_INFERENCE_FRAMEWORK_IMAGE_TAG, + ) + ensure_n_ready_private_llm_endpoints_short(1, user) + ensure_nonzero_available_llm_workers(create_endpoint_request["name"], user) + + for ( + completions_payload, + required_output_fields, + response_text_regex, + ) in LLM_PAYLOADS_WITH_EXPECTED_RESPONSES: + print( + f"Sending sync tasks to {create_endpoint_request['name']} for user {user}, {completions_payload=}..." + ) + try: + task_responses = asyncio.run( + create_llm_sync_tasks( + create_endpoint_request["name"], + [completions_payload], + user, + ) + ) + for response in task_responses: + ensure_llm_task_response_is_correct( + response, required_output_fields, response_text_regex + ) + except Exception as e: + if hasattr(e, "response") and e.response.status_code // 100 == 4: + print(f"Got 4xx status code for {completions_payload=}, which is expected") + else: + raise e + + for ( + completions_payload, + required_output_fields, + response_text_regex, + ) in LLM_PAYLOADS_WITH_EXPECTED_RESPONSES: + print( + f"Sending streaming tasks to {create_endpoint_request['name']} for user {user}, {completions_payload=}..." + ) + try: + task_responses = asyncio.run( + create_llm_streaming_tasks( + create_endpoint_request["name"], + [completions_payload], + user, + ) + ) + for response in task_responses: + ensure_llm_task_stream_response_is_correct( + response, required_output_fields, response_text_regex + ) + except Exception as e: + if hasattr(e, "response") and e.response.status_code // 100 == 4: + print(f"Got 4xx status code for {completions_payload=}, which is expected") + else: + raise e + finally: + delete_llm_model_endpoint(create_endpoint_request["name"], user) diff --git a/integration_tests/test_docs.py b/integration_tests/test_docs.py new file mode 100644 index 00000000..1b5ff941 --- /dev/null +++ b/integration_tests/test_docs.py @@ -0,0 +1,234 @@ +# Ignore lint errors for f-strings because the f-strings are actually regex expressions. +# flake8: noqa: W605 +import importlib.util +import os +import re +from pathlib import Path +from textwrap import dedent + +import pytest +from _pytest.assertion.rewrite import AssertionRewritingHook + +from .rest_api_utils import ( + BASE_PATH, + SERVICE_IDENTIFIER, + delete_existing_endpoints, + ensure_gateway_ready, +) + +ROOT_DIR = Path(__file__).parent.parent + +TEST_SKIP_MAGIC_STRING = "# test='skip'" + + +@pytest.fixture +def tmp_work_path(tmp_path: Path): + """ + Create a temporary working directory. + """ + previous_cwd = Path.cwd() + os.chdir(tmp_path) + + yield tmp_path + + os.chdir(previous_cwd) + + +class SetEnv: + def __init__(self): + self.envars = set() + + def __call__(self, name, value): + self.envars.add(name) + os.environ[name] = value + + def clear(self): + for n in self.envars: + os.environ.pop(n) + + +@pytest.fixture +def env(): + setenv = SetEnv() + + yield setenv + + setenv.clear() + + +@pytest.fixture() +def integration_test_user_id() -> str: + return os.getenv("TEST_USER_ID", "fakeuser") + + +def modify_source(source: str) -> str: + """Modify the source code from docs to be compatible with the integration tests.""" + + # Ensure the correct base path is used + source = re.sub( + r"get_launch_client\((.*)\)\n", + rf'get_launch_client(\g<1>, gateway_endpoint="{BASE_PATH}")\n', + source, + ) + source = re.sub( + r"LaunchClient\((.*)\)\n", + rf'LaunchClient(\g<1>, endpoint="{BASE_PATH}")\n', + source, + ) + + # Add suffix to avoid name collisions + source = re.sub( + r"('endpoint_name'|\"endpoint_name\"): ('([\w-]+)'|\"([\w-]+)\")", + rf"'endpoint_name': '\g<3>\g<4>-{SERVICE_IDENTIFIER}'", + source, + ) + source = re.sub( + r"endpoint_name=('([\w-]+)'|\"([\w-]+)\")", + rf"endpoint_name='\g<2>\g<3>-{SERVICE_IDENTIFIER}'", + source, + ) + source = re.sub( + r"get_model_endpoint\(\"([\w-]+)\"\)", + rf'get_model_endpoint("\g<1>-{SERVICE_IDENTIFIER}")', + source, + ) + + # Set particular tag values for cost tracking + source = re.sub(r"('team'|\"team\"): ('\w+'|\"\w+\")", r"'team': 'infra'", source) + source = re.sub( + r"('product'|\"product\"): ('\w+'|\"\w+\")", + r"'product': 'launch-integration-test'", + source, + ) + + # Fill in empty values in docs + source = re.sub(r'"repository": "..."', '"repository": "launch_rearch"', source) + source = re.sub( + r'"tag": "..."', '"tag": "11d9d42047cc9a0c6435b19e5e91bc7e0ad31efc-cpu"', source + ) + source = re.sub( + r'"command": ...', + """"command": [ + "dumb-init", + "--", + "ddtrace-run", + "run-service", + "--config", + "/install/launch_rearch/config/service--user_defined_code.yaml", + "--concurrency", + "1", + "--http", + "production", + "--port", + "5005", + ]""", + source, + ) + source = re.sub( + r'"streaming_command": ...', + """"streaming_command": [ + "dumb-init", + "--", + "ddtrace-run", + "run-streamer", + "--config", + "/install/std-ml-srv/tests/resources/example_echo_streaming_service_configuration.yaml", + "--concurrency", + "1", + "--http-mode", + "production", + "--port", + "5005", + ]""", + source, + ) + return source + + +@pytest.fixture +def import_execute(request, tmp_work_path: Path): + def _import_execute(module_name: str, source: str, rewrite_assertions: bool = False): + if rewrite_assertions: + loader = AssertionRewritingHook(config=request.config) + loader.mark_rewrite(module_name) + else: + loader = None + + module_path = tmp_work_path / f"{module_name}.py" + modified_source = modify_source(source) + module_path.write_text(modified_source) + spec = importlib.util.spec_from_file_location("__main__", str(module_path), loader=loader) + module = importlib.util.module_from_spec(spec) + try: + spec.loader.exec_module(module) + except KeyboardInterrupt: + print("KeyboardInterrupt") + + return _import_execute + + +def extract_code_chunks(path: Path, text: str, offset: int): + rel_path = path.relative_to(ROOT_DIR) + for m_code in re.finditer(r"```(.*?)$\n(.*?)\n( *)```", text, flags=re.M | re.S): + prefix = m_code.group(1).lower() + if not prefix.startswith(("py", "{.py")): + continue + + start_line = offset + text[: m_code.start()].count("\n") + 1 + code = dedent(m_code.group(2)) + end_line = start_line + code.count("\n") + 1 + source = "\n" * start_line + code + if TEST_SKIP_MAGIC_STRING in prefix or TEST_SKIP_MAGIC_STRING in code: + source = "__skip__" + yield pytest.param( + f"{path.stem}_{start_line}_{end_line}", source, id=f"{rel_path}:{start_line}-{end_line}" + ) + + +def generate_code_chunks(*directories: str): + for d in directories: + for path in (ROOT_DIR / d).glob("**/*"): + if path.suffix == ".py": + code = path.read_text() + for m_docstring in re.finditer(r'(^\s*)r?"""$(.*?)\1"""', code, flags=re.M | re.S): + start_line = code[: m_docstring.start()].count("\n") + docstring = m_docstring.group(2) + yield from extract_code_chunks(path, docstring, start_line) + elif path.suffix == ".md": + # TODO: remove this hack to skip llms.md + if "llms.md" in path.name: + continue + code = path.read_text() + yield from extract_code_chunks(path, code, 0) + + +# Assumes that launch-python-client is cloned at `models/launch-python-client` +@pytest.mark.parametrize( + "module_name,source_code", + generate_code_chunks( + "launch-python-client/docs", + "launch-python-client/launch", + "launch_internal/docs", + "launch_internal/launch_internal", + ), +) +def test_docs_examples( + module_name, + source_code, + import_execute, + env, + integration_test_user_id, +): + if source_code == "__skip__": + pytest.skip("test='skip' on code snippet") + + env("LAUNCH_API_KEY", os.getenv("LAUNCH_TEST_API_KEY", integration_test_user_id)) + + ensure_gateway_ready() + + try: + import_execute(module_name, source_code, True) + except Exception: + raise + finally: + delete_existing_endpoints() diff --git a/integration_tests/test_endpoints.py b/integration_tests/test_endpoints.py new file mode 100644 index 00000000..5d7eae2a --- /dev/null +++ b/integration_tests/test_endpoints.py @@ -0,0 +1,254 @@ +import asyncio +import time + +import pytest +from tenacity import RetryError, retry, retry_if_exception_type, stop_after_attempt, wait_fixed + +from .rest_api_utils import ( + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE, + CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE, + CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + INFERENCE_PAYLOAD, + INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE, + INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE, + UPDATE_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE, + USER_ID_0, + create_async_tasks, + create_model_endpoint, + create_streaming_tasks, + create_sync_tasks, + delete_existing_endpoints, + delete_model_endpoint, + ensure_all_async_tasks_success, + ensure_gateway_ready, + ensure_inference_task_response_is_correct, + ensure_n_ready_endpoints_long, + ensure_n_ready_endpoints_short, + ensure_nonzero_available_workers, + get_model_endpoint, + update_model_endpoint, +) + + +@pytest.fixture(autouse=True) +def delete_endpoints(capsys): + try: + ensure_gateway_ready() + delete_existing_endpoints() + except Exception: + with capsys.disabled(): + print("Endpoint deletion failed") + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(10), retry=retry_if_exception_type(RetryError)) +def ensure_async_inference_works(user, create_endpoint_request, inference_payload, return_pickled): + print( + f"Sending async tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..." + ) + task_ids = asyncio.run( + create_async_tasks( + create_endpoint_request["name"], + [inference_payload] * 3, + user, + ) + ) + print("Retrieving async task results...") + ensure_nonzero_available_workers(create_endpoint_request["name"], user) + ensure_all_async_tasks_success(task_ids, user, return_pickled) + + +@retry(stop=stop_after_attempt(3), wait=wait_fixed(20)) +def ensure_endpoint_updated(create_endpoint_request, update_endpoint_request, user): + endpoint = get_model_endpoint(create_endpoint_request["name"], user) + assert endpoint["resource_state"]["cpus"] == update_endpoint_request["cpus"] + assert endpoint["resource_state"]["memory"] == update_endpoint_request["memory"] + assert endpoint["deployment_state"]["max_workers"] == update_endpoint_request["max_workers"] + + +@pytest.mark.parametrize( + "create_endpoint_request,update_endpoint_request,inference_requests", + [ + ( + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_SIMPLE, + UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE, + [ + (INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE, True), + (INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE, False), + ], + ), + ( + CREATE_ASYNC_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + UPDATE_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE, + [(INFERENCE_PAYLOAD, False)], + ), + ], +) +def test_async_model_endpoint( + capsys, create_endpoint_request, update_endpoint_request, inference_requests +): + with capsys.disabled(): + try: + user = USER_ID_0 + print(f"Creating {create_endpoint_request['name']} model endpoint...") + create_model_endpoint(create_endpoint_request, user) + ensure_n_ready_endpoints_long(1, user) + + print(f"Updating {create_endpoint_request['name']} model endpoint...") + update_model_endpoint( + create_endpoint_request["name"], + update_endpoint_request, + user, + ) + # Let the cache update + time.sleep(60) + # Endpoint builds should be cached now. + ensure_n_ready_endpoints_short(1, user) + + print("Checking endpoint state...") + ensure_endpoint_updated(create_endpoint_request, update_endpoint_request, user) + + time.sleep(20) + + for inference_payload, return_pickled in inference_requests: + ensure_async_inference_works( + user, create_endpoint_request, inference_payload, return_pickled + ) + finally: + delete_model_endpoint(create_endpoint_request["name"], user) + + +def test_sync_model_endpoint(capsys): + with capsys.disabled(): + try: + user = USER_ID_0 + create_endpoint_request = CREATE_SYNC_MODEL_ENDPOINT_REQUEST_SIMPLE + update_endpoint_request = UPDATE_MODEL_ENDPOINT_REQUEST_SIMPLE + inference_requests = [ + (INFERENCE_PAYLOAD_RETURN_PICKLED_TRUE, True), + (INFERENCE_PAYLOAD_RETURN_PICKLED_FALSE, False), + ] + + print(f"Creating {create_endpoint_request['name']} model endpoint...") + create_model_endpoint(create_endpoint_request, user) + ensure_n_ready_endpoints_short(1, user) + + print(f"Updating {create_endpoint_request['name']} model endpoint...") + update_model_endpoint( + create_endpoint_request["name"], + update_endpoint_request, + user, + ) + # Let the cache update + time.sleep(30) + # Endpoint builds should be cached now. + ensure_n_ready_endpoints_short(1, user) + ensure_nonzero_available_workers(create_endpoint_request["name"], user) + + print("Checking endpoint state...") + endpoint = get_model_endpoint(create_endpoint_request["name"], user) + assert endpoint["resource_state"]["cpus"] == update_endpoint_request["cpus"] + assert endpoint["resource_state"]["memory"] == update_endpoint_request["memory"] + assert ( + endpoint["deployment_state"]["max_workers"] + == update_endpoint_request["max_workers"] + ) + + time.sleep(10) + + for inference_payload, return_pickled in inference_requests: + print( + f"Sending sync tasks to {create_endpoint_request['name']} for user {user}, {inference_payload=}, {return_pickled=} ..." + ) + task_responses = asyncio.run( + create_sync_tasks( + create_endpoint_request["name"], + [inference_payload], + user, + ) + ) + for response in task_responses: + ensure_inference_task_response_is_correct(response, return_pickled) + finally: + delete_model_endpoint(create_endpoint_request["name"], user) + + +def test_sync_streaming_model_endpoint(capsys): + with capsys.disabled(): + try: + user = USER_ID_0 + create_endpoint_request = CREATE_SYNC_STREAMING_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE + update_endpoint_request = UPDATE_MODEL_ENDPOINT_REQUEST_RUNNABLE_IMAGE + + print(f"Creating {create_endpoint_request['name']} model endpoint...") + create_model_endpoint(create_endpoint_request, user) + ensure_n_ready_endpoints_short(1, user) + + print(f"Updating {create_endpoint_request['name']} model endpoint...") + update_model_endpoint( + create_endpoint_request["name"], + update_endpoint_request, + user, + ) + # Let the cache update + time.sleep(30) + # Endpoint builds should be cached now. + ensure_n_ready_endpoints_short(1, user) + ensure_nonzero_available_workers(create_endpoint_request["name"], user) + + print("Checking endpoint state...") + endpoint = get_model_endpoint(create_endpoint_request["name"], user) + assert endpoint["resource_state"]["cpus"] == update_endpoint_request["cpus"] + assert endpoint["resource_state"]["memory"] == update_endpoint_request["memory"] + assert ( + endpoint["deployment_state"]["max_workers"] + == update_endpoint_request["max_workers"] + ) + + time.sleep(5) + + print(f"Sending sync tasks to {create_endpoint_request['name']} for user {user} ...") + task_responses = asyncio.run( + create_sync_tasks( + create_endpoint_request["name"], + [INFERENCE_PAYLOAD] * 3, + user, + ) + ) + for response in task_responses: + ensure_inference_task_response_is_correct(response, False) + + print( + f"Sending streaming tasks to {create_endpoint_request['name']} for user {user} ..." + ) + task_responses = asyncio.run( + create_streaming_tasks( + create_endpoint_request["name"], + [INFERENCE_PAYLOAD] * 5, + user, + ) + ) + for response in task_responses: + assert ( + response.strip() + == 'data: {"status":"SUCCESS","result":{"result":{"y":1}},"traceback":null}' + ) + finally: + delete_model_endpoint(create_endpoint_request["name"], user) + + +@pytest.mark.skipif( + reason="Need to update the following test to hit remote service to be integration test" +) +def test_models_tokenizers() -> None: + from model_engine_server.infra.gateways.s3_llm_artifact_gateway import S3LLMArtifactGateway + from model_engine_server.infra.repositories import LiveTokenizerRepository + from model_engine_server.infra.repositories.live_tokenizer_repository import ( + SUPPORTED_MODELS_INFO, + ) + + llm_artifact_gateway = S3LLMArtifactGateway() + tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway) + for model_name in SUPPORTED_MODELS_INFO: + tokenizer_repository.load_tokenizer(model_name) diff --git a/integration_tests/test_file.py b/integration_tests/test_file.py new file mode 100644 index 00000000..53c10345 --- /dev/null +++ b/integration_tests/test_file.py @@ -0,0 +1,27 @@ +from .rest_api_utils import ( # list_files, delete_file_by_id, + get_file_by_id, + get_file_content_by_id, + upload_file, +) + + +def test_files() -> None: + user = "62bc820451dbea002b1c5421" # CDS needs proper user ID + + upload_response = upload_file(open(__file__, "rb"), user) + file_id = upload_response["id"] + + content = get_file_content_by_id(file_id, user) + assert content["id"] == file_id + assert content["content"] + + get_response = get_file_by_id(file_id, user) + assert get_response["id"] == file_id + assert get_response["filename"] == "test_file.py" + + # TODO: add tests back + # list_response = list_files(user) + # assert len(list_response["files"]) > 0 + + # delete_response = delete_file_by_id(file_id, user) + # assert delete_response["deleted"] diff --git a/integration_tests/test_fine_tunes.py b/integration_tests/test_fine_tunes.py new file mode 100644 index 00000000..89d9e447 --- /dev/null +++ b/integration_tests/test_fine_tunes.py @@ -0,0 +1,73 @@ +import json +import os +import time + +import boto3 +import pytest +import smart_open + +from .rest_api_utils import ( + CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, + CREATE_FINE_TUNE_REQUEST, + USER_ID_0, + cancel_fine_tune_by_id, + create_docker_image_batch_job_bundle, + create_fine_tune, + get_fine_tune_by_id, + list_fine_tunes, +) + +MAX_RETRIES = 10 + + +@pytest.mark.skipif( + not os.getenv("FINE_TUNE_TEST_READY"), + reason="Skipping fine tune tests when test templates are not set up.", +) +def test_fine_tunes() -> None: + di_batch_job_id = create_docker_image_batch_job_bundle( + CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, USER_ID_0 + )["docker_image_batch_job_bundle_id"] + data = { + "test_base_model-lora": { + "docker_image_batch_job_bundle_id": di_batch_job_id, + "launch_bundle_config": {}, + "launch_endpoint_config": {}, + "default_hparams": {}, + "required_params": [], + } + } + + if os.getenv("CIRCLECI") == "true": + session = boto3.Session() + aws_s3_bucket = os.getenv("CIRCLECI_AWS_S3_BUCKET") + client = session.client("s3") + with smart_open.open( + f"s3://{aws_s3_bucket}/fine_tune_repository", + "w", + transport_params={"client": client}, + ) as f: + json.dump(data, f) + + create_response = create_fine_tune(CREATE_FINE_TUNE_REQUEST, USER_ID_0) + fine_tune_id = create_response["id"] + + get_response = get_fine_tune_by_id(fine_tune_id, USER_ID_0) + num_retries = 0 + while get_response["status"] not in ["SUCCESS", "FAILURE"]: + if num_retries >= MAX_RETRIES: + raise Exception("Fine tune job did not complete in time.") + num_retries += 1 + get_response = get_fine_tune_by_id(fine_tune_id, USER_ID_0) + time.sleep(10) + assert get_response["id"] == fine_tune_id + assert get_response["status"] == "SUCCESS" + + list_response_0_before = list_fine_tunes(USER_ID_0) + num_jobs = len(list_response_0_before["jobs"]) + assert num_jobs >= 1 + + cancel_fine_tune_by_id(fine_tune_id, USER_ID_0) + + list_response_0_after = list_fine_tunes(USER_ID_0) + assert len(list_response_0_after["jobs"]) == num_jobs - 1 diff --git a/mkdocs.yml b/mkdocs.yml index 45a7329c..b719a2cb 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -46,6 +46,7 @@ nav: - "API Reference": api/python_client.md - "Data Type Reference": api/data_types.md - "Error handling": api/error_handling.md + - "Integrations": integrations.md - "Pricing": pricing.md - "Contributing": contributing.md # - "FAQ": faq.md @@ -87,17 +88,17 @@ markdown_extensions: - neoteroi.cards - footnotes +watch: + - clients/python/llmengine + plugins: - search - mkdocstrings: - watch: - - clients/python/llmengine handlers: python: options: separate_signature: true line_length: 60 - rendering: show_root_heading: true show_root_full_path: false show_source: false diff --git a/server/Dockerfile b/model-engine/Dockerfile similarity index 54% rename from server/Dockerfile rename to model-engine/Dockerfile index 59d466cb..45cd9630 100644 --- a/server/Dockerfile +++ b/model-engine/Dockerfile @@ -1,6 +1,6 @@ # syntax = docker/dockerfile:experimental -FROM python:3.8.8-slim as llm-engine +FROM python:3.10.15-slim as model-engine WORKDIR /workspace RUN apt-get update && apt-get install -y \ @@ -18,7 +18,6 @@ RUN apt-get update && apt-get install -y \ python3-dev \ gcc \ build-essential \ - postgresql \ telnet \ && rm -rf /var/lib/apt/lists/* @@ -26,35 +25,33 @@ RUN curl -Lo /bin/aws-iam-authenticator https://github.com/kubernetes-sigs/aws-i RUN chmod +x /bin/aws-iam-authenticator # Install kubectl -RUN curl -LO "https://dl.k8s.io/release/v1.17.9/bin/linux/amd64/kubectl" \ +RUN curl -LO "https://dl.k8s.io/release/v1.23.13/bin/linux/amd64/kubectl" \ && chmod +x kubectl \ && mv kubectl /usr/local/bin/kubectl # Pin pip version -RUN pip install pip==23.0.1 +RUN pip install pip==24.2 RUN chmod -R 777 /workspace -## grab llm_engine_server project (w/ requirements install layer caching) -WORKDIR /workspace/server/ -COPY server/requirements-test.txt /workspace/server/requirements-test.txt -COPY server/requirements.txt /workspace/server/requirements.txt -COPY server/requirements_override.txt /workspace/server/requirements_override.txt +# Install AWS CLI +RUN pip install awscli==1.34.28 --no-cache-dir + +## grab model_engine_server project (w/ requirements install layer caching) +WORKDIR /workspace/model-engine/ +COPY model-engine/requirements-test.txt /workspace/model-engine/requirements-test.txt +COPY model-engine/requirements.txt /workspace/model-engine/requirements.txt +COPY model-engine/requirements_override.txt /workspace/model-engine/requirements_override.txt RUN pip install -r requirements-test.txt --no-cache-dir RUN pip install -r requirements.txt --no-cache-dir RUN pip install -r requirements_override.txt --no-cache-dir -COPY server/pyproject.toml /workspace/server/pyproject.toml -COPY server/setup.py /workspace/server/setup.py -COPY server/llm_engine_server /workspace/server/llm_engine_server +COPY model-engine/setup.py /workspace/model-engine/setup.py +COPY model-engine/model_engine_server /workspace/model-engine/model_engine_server RUN pip install -e . +COPY integration_tests /workspace/integration_tests + WORKDIR /workspace ENV PYTHONPATH /workspace ENV WORKSPACE /workspace EXPOSE 5000 -EXPOSE 5001 -EXPOSE 5002 -EXPOSE 5005 - -RUN useradd -m user -s /bin/bash -USER user diff --git a/model-engine/README.md b/model-engine/README.md new file mode 100644 index 00000000..3f6b9579 --- /dev/null +++ b/model-engine/README.md @@ -0,0 +1,49 @@ +# Model Engine + +The Model Engine is an API server that allows users to create, deploy, edit, +and delete machine learning endpoints. It consists of two main architectural +components: + +- The [gateway](./model_engine_server/entrypoints/start_fastapi_server.py) + provides a REST API for users to interact with. The routes of the REST API are + defined in [`model_engine_server.api`](./model_engine_server/api). +- The [`model_engine_server.service_builder`](./model_engine_server/service_builder) + package is the part of the code that creates the inference pods. It is the + endpoint builder. When we do a `POST` request to `/endpoints`, this gets run. + It gets run when users create or edit endpoints with `[POST,PUT] /v1/model-endpoints` + +There are two other microservices: + +- The [kubernetes cache](./model_engine_server/entrypoints/k8s_cache.py) + stores endpoint metadata on Redis so that Model Engine does not overload the API + server. +- The celery autoscaler (link TBD) automatically scales + the number of inference pods based on the number of requests for async endpoints. + +## Getting started + +Be sure to install the global `../requirements-dev.txt` first prior +to any installations of requirements in this directory +(`pip install -r ../requirements-dev.txt`), as well as the pre-commit hooks +(`pre-commit install` in the `llm-engine` root folder). Then, install the +requirements files and this folder as editable + +```bash +pip install -r requirements.txt && \ + pip install -r requirements-test.txt && \ + pip install -r requirements_override.txt && \ + pip install -e . +``` + +Run `mypy . --install-types` to set up mypy. + +## Testing + +Most of the business logic in Model Engine should contain unit tests, located in +[`tests/unit`](./tests/unit). To run the tests, run `pytest`. + +## Generating OpenAI types +We've decided to make our V2 APIs OpenAI compatible. We generate the +corresponding Pydantic models: +1. Fetch the OpenAPI spec from https://github.com/openai/openai-openapi/blob/master/openapi.yaml +2. Run scripts/generate-openai-types.sh diff --git a/server/llm_engine_server/core/__init__.py b/model-engine/model_engine_server/__init__.py similarity index 100% rename from server/llm_engine_server/core/__init__.py rename to model-engine/model_engine_server/__init__.py diff --git a/server/llm_engine_server/core/auth/__init__.py b/model-engine/model_engine_server/api/__init__.py similarity index 100% rename from server/llm_engine_server/core/auth/__init__.py rename to model-engine/model_engine_server/api/__init__.py diff --git a/model-engine/model_engine_server/api/app.py b/model-engine/model_engine_server/api/app.py new file mode 100644 index 00000000..2f7a4b0e --- /dev/null +++ b/model-engine/model_engine_server/api/app.py @@ -0,0 +1,127 @@ +import os +import traceback +import uuid +from datetime import datetime +from pathlib import Path + +import pytz +from fastapi import FastAPI, HTTPException, Request, Response +from fastapi.responses import JSONResponse +from fastapi.staticfiles import StaticFiles +from model_engine_server.api.batch_jobs_v1 import batch_job_router_v1 +from model_engine_server.api.dependencies import get_or_create_aioredis_pool +from model_engine_server.api.docker_image_batch_job_bundles_v1 import ( + docker_image_batch_job_bundle_router_v1, +) +from model_engine_server.api.files_v1 import file_router_v1 +from model_engine_server.api.llms_v1 import llm_router_v1 +from model_engine_server.api.model_bundles_v1 import model_bundle_router_v1 +from model_engine_server.api.model_bundles_v2 import model_bundle_router_v2 +from model_engine_server.api.model_endpoints_docs_v1 import model_endpoints_docs_router_v1 +from model_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1 +from model_engine_server.api.tasks_v1 import inference_task_router_v1 +from model_engine_server.api.triggers_v1 import trigger_router_v1 +from model_engine_server.api.v2 import llm_router_v2 +from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from starlette.middleware import Middleware +from starlette.middleware.base import BaseHTTPMiddleware + +logger = make_logger(logger_name()) + +# Allows us to make the Uvicorn worker concurrency in model_engine_server/api/worker.py very high +MAX_CONCURRENCY = 500 + +concurrency_limiter = MultiprocessingConcurrencyLimiter( + concurrency=MAX_CONCURRENCY, fail_on_concurrency_limit=True +) + +healthcheck_routes = ["/healthcheck", "/healthz", "/readyz"] + + +class CustomMiddleware(BaseHTTPMiddleware): + async def dispatch(self, request: Request, call_next): + try: + LoggerTagManager.set(LoggerTagKey.REQUEST_ID, str(uuid.uuid4())) + LoggerTagManager.set(LoggerTagKey.REQUEST_SIZE, request.headers.get("content-length")) + # we intentionally exclude healthcheck routes from the concurrency limiter + if request.url.path in healthcheck_routes: + return await call_next(request) + with concurrency_limiter: + return await call_next(request) + except HTTPException as e: + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + return JSONResponse( + status_code=e.status_code, + content={ + "error": e.detail, + "timestamp": timestamp, + }, + ) + except Exception as e: + tb_str = traceback.format_exception(e) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": str(e), + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Unhandled exception: %s", structured_log) + return JSONResponse( + status_code=500, + content={ + "error": "Internal error occurred. Our team has been notified.", + "timestamp": timestamp, + "request_id": request_id, + }, + ) + + +app = FastAPI( + title="launch", + version="1.0.0", + redoc_url="/api", + middleware=[Middleware(CustomMiddleware)], +) + +app.include_router(batch_job_router_v1) +app.include_router(inference_task_router_v1) +app.include_router(model_bundle_router_v1) +app.include_router(model_bundle_router_v2) +app.include_router(model_endpoint_router_v1) +app.include_router(model_endpoints_docs_router_v1) +app.include_router(docker_image_batch_job_bundle_router_v1) +app.include_router(llm_router_v1) +app.include_router(file_router_v1) +app.include_router(trigger_router_v1) +app.include_router(llm_router_v2) + + +# TODO: Remove this once we have a better way to serve internal docs +INTERNAL_DOCS_PATH = str(Path(__file__).parents[3] / "launch_internal/site") +if os.path.exists(INTERNAL_DOCS_PATH): + app.mount( + "/python-docs", + StaticFiles(directory=INTERNAL_DOCS_PATH, html=True), + name="python-docs", + ) + + +@app.on_event("startup") +def load_redis(): + get_or_create_aioredis_pool() + + +def healthcheck() -> Response: + """Returns 200 if the app is healthy.""" + return Response(status_code=200) + + +for endpoint in healthcheck_routes: + app.get(endpoint)(healthcheck) diff --git a/server/llm_engine_server/api/batch_jobs_v1.py b/model-engine/model_engine_server/api/batch_jobs_v1.py similarity index 79% rename from server/llm_engine_server/api/batch_jobs_v1.py rename to model-engine/model_engine_server/api/batch_jobs_v1.py index 86f46ff9..1724c5a7 100644 --- a/server/llm_engine_server/api/batch_jobs_v1.py +++ b/model-engine/model_engine_server/api/batch_jobs_v1.py @@ -1,47 +1,48 @@ -from fastapi import APIRouter, Depends, HTTPException -from llm_engine_server.api.dependencies import ( +from typing import Optional + +from fastapi import APIRouter, Depends, HTTPException, Query +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateBatchJobV1Request, CreateBatchJobV1Response, CreateDockerImageBatchJobV1Request, CreateDockerImageBatchJobV1Response, GetBatchJobV1Response, GetDockerImageBatchJobV1Response, + ListDockerImageBatchJobsV1Response, UpdateBatchJobV1Request, UpdateBatchJobV1Response, UpdateDockerImageBatchJobV1Request, UpdateDockerImageBatchJobV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, + EndpointLabelsException, + EndpointResourceInvalidRequestException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import ( - EndpointLabelsException, - EndpointResourceInvalidRequestException, -) -from llm_engine_server.domain.use_cases.batch_job_use_cases import ( +from model_engine_server.domain.use_cases.batch_job_use_cases import ( CreateBatchJobV1UseCase, CreateDockerImageBatchJobV1UseCase, GetBatchJobV1UseCase, GetDockerImageBatchJobV1UseCase, + ListDockerImageBatchJobsV1UseCase, UpdateBatchJobV1UseCase, UpdateDockerImageBatchJobV1UseCase, ) batch_job_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @batch_job_router_v1.post("/batch-jobs", response_model=CreateBatchJobV1Response) @@ -53,7 +54,6 @@ async def create_batch_job( """ Runs a batch job. """ - add_trace_resource_name("batch_jobs_post") logger.info(f"POST /batch-jobs with {request} for {auth}") try: use_case = CreateBatchJobV1UseCase( @@ -83,7 +83,6 @@ async def get_batch_job( """ Gets a batch job. """ - add_trace_resource_name("batch_jobs_get") logger.info(f"GET /batch-jobs/{batch_job_id} for {auth}") try: use_case = GetBatchJobV1UseCase(batch_job_service=external_interfaces.batch_job_service) @@ -105,7 +104,6 @@ async def update_batch_job( """ Updates a batch job. """ - add_trace_resource_name("batch_jobs_put") logger.info(f"PUT /batch-jobs/{batch_job_id} for {auth}") try: use_case = UpdateBatchJobV1UseCase(batch_job_service=external_interfaces.batch_job_service) @@ -125,8 +123,6 @@ async def create_docker_image_batch_job( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> CreateDockerImageBatchJobV1Response: - - add_trace_resource_name("batch_jobs_di_create") logger.info(f"POST /docker-image-batch-jobs with {request} for {auth}") try: use_case = CreateDockerImageBatchJobV1UseCase( @@ -151,23 +147,20 @@ async def create_docker_image_batch_job( ) from exc except EndpointResourceInvalidRequestException as exc: raise HTTPException( - status_code=400, - detail=f"Final endpoint resources requested is invalid: {exc}", + status_code=400, detail=f"Final endpoint resources requested is invalid: {exc}" ) from exc except EndpointLabelsException as exc: raise HTTPException(status_code=400, detail=str(exc)) from exc @batch_job_router_v1.get( - "/docker-image-batch-jobs/{batch_job_id}", - response_model=GetDockerImageBatchJobV1Response, + "/docker-image-batch-jobs/{batch_job_id}", response_model=GetDockerImageBatchJobV1Response ) async def get_docker_image_batch_job( batch_job_id: str, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> GetDockerImageBatchJobV1Response: - add_trace_resource_name("batch_jobs_di_get") logger.info(f"GET /docker-image-batch-jobs/{batch_job_id} for {auth}") try: use_case = GetDockerImageBatchJobV1UseCase( @@ -180,9 +173,31 @@ async def get_docker_image_batch_job( ) from exc +@batch_job_router_v1.get( + "/docker-image-batch-jobs", + response_model=ListDockerImageBatchJobsV1Response, +) +async def list_docker_image_batch_jobs( + trigger_id: Optional[str] = Query(default=None), + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> ListDockerImageBatchJobsV1Response: + """ + Lists docker image batch jobs spawned by trigger with given ID + """ + logger.info(f"GET /docker-image-batch-jobs?trigger_id={trigger_id}") + try: + use_case = ListDockerImageBatchJobsV1UseCase( + trigger_repository=external_interfaces.trigger_repository, + cron_job_gateway=external_interfaces.cron_job_gateway, + ) + return await use_case.execute(user=auth, trigger_id=trigger_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=f"Trigger {trigger_id} was not found.") from exc + + @batch_job_router_v1.put( - "/docker-image-batch-jobs/{batch_job_id}", - response_model=UpdateDockerImageBatchJobV1Response, + "/docker-image-batch-jobs/{batch_job_id}", response_model=UpdateDockerImageBatchJobV1Response ) async def update_docker_image_batch_job( batch_job_id: str, @@ -190,7 +205,6 @@ async def update_docker_image_batch_job( auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> UpdateDockerImageBatchJobV1Response: - add_trace_resource_name("batch_jobs_di_put") logger.info(f"PUT /docker-image-batch-jobs/{batch_job_id} with {request} for {auth}") try: use_case = UpdateDockerImageBatchJobV1UseCase( diff --git a/model-engine/model_engine_server/api/dependencies.py b/model-engine/model_engine_server/api/dependencies.py new file mode 100644 index 00000000..e120fbf0 --- /dev/null +++ b/model-engine/model_engine_server/api/dependencies.py @@ -0,0 +1,504 @@ +import asyncio +import os +import time +from dataclasses import dataclass +from typing import Callable, Optional + +import aioredis +from fastapi import Depends, HTTPException, status +from fastapi.security import HTTPBasic, HTTPBasicCredentials, OAuth2PasswordBearer +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User +from model_engine_server.core.auth.fake_authentication_repository import ( + FakeAuthenticationRepository, +) +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.db.base import get_session_async, get_session_read_only_async +from model_engine_server.domain.gateways import ( + CronJobGateway, + DockerImageBatchJobGateway, + FileStorageGateway, + LLMArtifactGateway, + ModelPrimitiveGateway, + MonitoringMetricsGateway, + TaskQueueGateway, +) +from model_engine_server.domain.repositories import ( + DockerImageBatchJobBundleRepository, + DockerRepository, + LLMFineTuneEventsRepository, + ModelBundleRepository, + TokenizerRepository, + TriggerRepository, +) +from model_engine_server.domain.services import ( + BatchJobService, + LLMFineTuningService, + LLMModelEndpointService, + ModelEndpointService, +) +from model_engine_server.domain.services.llm_batch_completions_service import ( + LLMBatchCompletionsService, +) +from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( + StreamingStorageGateway, +) +from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( + FirehoseStreamingStorageGateway, +) +from model_engine_server.infra.gateways import ( + ABSFileStorageGateway, + ABSFilesystemGateway, + ABSLLMArtifactGateway, + ASBInferenceAutoscalingMetricsGateway, + CeleryTaskQueueGateway, + DatadogMonitoringMetricsGateway, + FakeMonitoringMetricsGateway, + LiveAsyncModelEndpointInferenceGateway, + LiveBatchJobOrchestrationGateway, + LiveBatchJobProgressGateway, + LiveCronJobGateway, + LiveDockerImageBatchJobGateway, + LiveModelEndpointInfraGateway, + LiveModelEndpointsSchemaGateway, + LiveStreamingModelEndpointInferenceGateway, + LiveSyncModelEndpointInferenceGateway, + ModelEndpointInfraGateway, + RedisInferenceAutoscalingMetricsGateway, + S3FilesystemGateway, + S3LLMArtifactGateway, +) +from model_engine_server.infra.gateways.fake_model_primitive_gateway import ( + FakeModelPrimitiveGateway, +) +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( + EndpointResourceGateway, +) +from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( + FakeQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( + LiveEndpointResourceGateway, +) +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.s3_file_storage_gateway import S3FileStorageGateway +from model_engine_server.infra.repositories import ( + ABSFileLLMFineTuneEventsRepository, + ABSFileLLMFineTuneRepository, + ACRDockerRepository, + DbBatchJobRecordRepository, + DbDockerImageBatchJobBundleRepository, + DbModelBundleRepository, + DbModelEndpointRecordRepository, + DbTriggerRepository, + ECRDockerRepository, + FakeDockerRepository, + LiveTokenizerRepository, + LLMFineTuneRepository, + RedisModelEndpointCacheRepository, + S3FileLLMFineTuneEventsRepository, + S3FileLLMFineTuneRepository, +) +from model_engine_server.infra.services import ( + DockerImageBatchJobLLMFineTuningService, + LiveBatchJobService, + LiveModelEndpointService, +) +from model_engine_server.infra.services.live_llm_batch_completions_service import ( + LiveLLMBatchCompletionsService, +) +from model_engine_server.infra.services.live_llm_model_endpoint_service import ( + LiveLLMModelEndpointService, +) +from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session + +logger = make_logger(logger_name()) + +basic_auth = HTTPBasic(auto_error=False) +oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token", auto_error=False) + + +@dataclass +class ExternalInterfaces: + """ + Internal object used for aggregating various Gateway and Repository objects for dependency + injection. + """ + + docker_repository: DockerRepository + docker_image_batch_job_bundle_repository: DockerImageBatchJobBundleRepository + model_bundle_repository: ModelBundleRepository + trigger_repository: TriggerRepository + model_endpoint_service: ModelEndpointService + batch_job_service: BatchJobService + llm_model_endpoint_service: LLMModelEndpointService + llm_batch_completions_service: LLMBatchCompletionsService + llm_fine_tuning_service: LLMFineTuningService + llm_fine_tune_events_repository: LLMFineTuneEventsRepository + + resource_gateway: EndpointResourceGateway + endpoint_creation_task_queue_gateway: TaskQueueGateway + inference_task_queue_gateway: TaskQueueGateway + model_endpoint_infra_gateway: ModelEndpointInfraGateway + docker_image_batch_job_gateway: DockerImageBatchJobGateway + model_primitive_gateway: ModelPrimitiveGateway + file_storage_gateway: FileStorageGateway + filesystem_gateway: FilesystemGateway + llm_artifact_gateway: LLMArtifactGateway + cron_job_gateway: CronJobGateway + monitoring_metrics_gateway: MonitoringMetricsGateway + tokenizer_repository: TokenizerRepository + streaming_storage_gateway: StreamingStorageGateway + + +def get_default_monitoring_metrics_gateway() -> MonitoringMetricsGateway: + # dd_trace_enabled is a good enough proxy for determining if we should use Datadog + if hmi_config.dd_trace_enabled: + monitoring_metrics_gateway: MonitoringMetricsGateway = DatadogMonitoringMetricsGateway() + else: + monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + return monitoring_metrics_gateway + + +def get_monitoring_metrics_gateway() -> MonitoringMetricsGateway: + try: + from plugins.dependencies import ( + get_monitoring_metrics_gateway as get_custom_monitoring_metrics_gateway, + ) + + return get_custom_monitoring_metrics_gateway() + except ModuleNotFoundError: + return get_default_monitoring_metrics_gateway() + finally: + pass + + +def _get_external_interfaces( + read_only: bool, session: Callable[[], AsyncSession] +) -> ExternalInterfaces: + """ + Dependency that returns a ExternalInterfaces object. This allows repositories to share + sessions for the database and redis. + """ + redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS) + redis_24h_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS_24H) + sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) + servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS) + monitoring_metrics_gateway = get_monitoring_metrics_gateway() + model_endpoint_record_repo = DbModelEndpointRecordRepository( + monitoring_metrics_gateway=monitoring_metrics_gateway, + session=session, + read_only=read_only, + ) + + queue_delegate: QueueEndpointResourceDelegate + if CIRCLECI: + queue_delegate = FakeQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "azure": + queue_delegate = ASBQueueEndpointResourceDelegate() + else: + queue_delegate = SQSQueueEndpointResourceDelegate( + sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) + ) + + inference_task_queue_gateway: TaskQueueGateway + infra_task_queue_gateway: TaskQueueGateway + if CIRCLECI: + inference_task_queue_gateway = redis_24h_task_queue_gateway + infra_task_queue_gateway = redis_task_queue_gateway + elif infra_config().cloud_provider == "azure": + inference_task_queue_gateway = servicebus_task_queue_gateway + infra_task_queue_gateway = servicebus_task_queue_gateway + else: + inference_task_queue_gateway = sqs_task_queue_gateway + infra_task_queue_gateway = sqs_task_queue_gateway + redis_client = aioredis.Redis(connection_pool=get_or_create_aioredis_pool()) + inference_autoscaling_metrics_gateway = ( + ASBInferenceAutoscalingMetricsGateway() + if infra_config().cloud_provider == "azure" + else RedisInferenceAutoscalingMetricsGateway(redis_client=redis_client) + ) # we can just reuse the existing redis client, we shouldn't get key collisions because of the prefix + resource_gateway = LiveEndpointResourceGateway( + queue_delegate=queue_delegate, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, + ) + model_endpoint_cache_repo = RedisModelEndpointCacheRepository( + redis_client=redis_client, + ) + model_endpoint_infra_gateway = LiveModelEndpointInfraGateway( + resource_gateway=resource_gateway, + task_queue_gateway=infra_task_queue_gateway, + ) + async_model_endpoint_inference_gateway = LiveAsyncModelEndpointInferenceGateway( + task_queue_gateway=inference_task_queue_gateway + ) + # In CircleCI, we cannot use asyncio because aiohttp cannot connect to the sync endpoints. + sync_model_endpoint_inference_gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=monitoring_metrics_gateway, + use_asyncio=(not CIRCLECI), + ) + streaming_model_endpoint_inference_gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=monitoring_metrics_gateway, + use_asyncio=(not CIRCLECI), + ) + filesystem_gateway = ( + ABSFilesystemGateway() + if infra_config().cloud_provider == "azure" + else S3FilesystemGateway() + ) + llm_artifact_gateway = ( + ABSLLMArtifactGateway() + if infra_config().cloud_provider == "azure" + else S3LLMArtifactGateway() + ) + model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( + filesystem_gateway=filesystem_gateway + ) + model_endpoint_service = LiveModelEndpointService( + model_endpoint_record_repository=model_endpoint_record_repo, + model_endpoint_infra_gateway=model_endpoint_infra_gateway, + model_endpoint_cache_repository=model_endpoint_cache_repo, + async_model_endpoint_inference_gateway=async_model_endpoint_inference_gateway, + streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway, + sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, + model_endpoints_schema_gateway=model_endpoints_schema_gateway, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, + can_scale_http_endpoint_from_zero_flag=infra_config().prometheus_server_address is not None, + ) + llm_model_endpoint_service = LiveLLMModelEndpointService( + model_endpoint_record_repository=model_endpoint_record_repo, + model_endpoint_service=model_endpoint_service, + ) + model_bundle_repository = DbModelBundleRepository(session=session, read_only=read_only) + docker_image_batch_job_bundle_repository = DbDockerImageBatchJobBundleRepository( + session=session, read_only=read_only + ) + batch_job_record_repository = DbBatchJobRecordRepository(session=session, read_only=read_only) + trigger_repository = DbTriggerRepository(session=session, read_only=read_only) + batch_job_orchestration_gateway = LiveBatchJobOrchestrationGateway() + batch_job_progress_gateway = LiveBatchJobProgressGateway(filesystem_gateway=filesystem_gateway) + batch_job_service = LiveBatchJobService( + batch_job_record_repository=batch_job_record_repository, + model_endpoint_service=model_endpoint_service, + batch_job_orchestration_gateway=batch_job_orchestration_gateway, + batch_job_progress_gateway=batch_job_progress_gateway, + ) + + model_primitive_gateway = FakeModelPrimitiveGateway() + + docker_image_batch_job_gateway = LiveDockerImageBatchJobGateway() + cron_job_gateway = LiveCronJobGateway() + + llm_fine_tune_repository: LLMFineTuneRepository + file_path = os.getenv( + "CLOUD_FILE_LLM_FINE_TUNE_REPOSITORY", + hmi_config.cloud_file_llm_fine_tune_repository, + ) + if infra_config().cloud_provider == "azure": + llm_fine_tune_repository = ABSFileLLMFineTuneRepository( + file_path=file_path, + ) + else: + llm_fine_tune_repository = S3FileLLMFineTuneRepository( + file_path=file_path, + ) + llm_fine_tune_events_repository = ( + ABSFileLLMFineTuneEventsRepository() + if infra_config().cloud_provider == "azure" + else S3FileLLMFineTuneEventsRepository() + ) + llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService( + docker_image_batch_job_gateway=docker_image_batch_job_gateway, + docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository, + llm_fine_tune_repository=llm_fine_tune_repository, + ) + + llm_batch_completions_service = LiveLLMBatchCompletionsService( + docker_image_batch_job_gateway=docker_image_batch_job_gateway + ) + + file_storage_gateway = ( + ABSFileStorageGateway() + if infra_config().cloud_provider == "azure" + else S3FileStorageGateway() + ) + + docker_repository: DockerRepository + if CIRCLECI: + docker_repository = FakeDockerRepository() + elif infra_config().docker_repo_prefix.endswith("azurecr.io"): + docker_repository = ACRDockerRepository() + else: + docker_repository = ECRDockerRepository() + + tokenizer_repository = LiveTokenizerRepository(llm_artifact_gateway=llm_artifact_gateway) + + streaming_storage_gateway = FirehoseStreamingStorageGateway() + + external_interfaces = ExternalInterfaces( + docker_repository=docker_repository, + model_bundle_repository=model_bundle_repository, + model_endpoint_service=model_endpoint_service, + llm_model_endpoint_service=llm_model_endpoint_service, + llm_batch_completions_service=llm_batch_completions_service, + batch_job_service=batch_job_service, + resource_gateway=resource_gateway, + endpoint_creation_task_queue_gateway=infra_task_queue_gateway, + inference_task_queue_gateway=inference_task_queue_gateway, + model_endpoint_infra_gateway=model_endpoint_infra_gateway, + model_primitive_gateway=model_primitive_gateway, + docker_image_batch_job_bundle_repository=docker_image_batch_job_bundle_repository, + docker_image_batch_job_gateway=docker_image_batch_job_gateway, + llm_fine_tuning_service=llm_fine_tuning_service, + llm_fine_tune_events_repository=llm_fine_tune_events_repository, + file_storage_gateway=file_storage_gateway, + filesystem_gateway=filesystem_gateway, + llm_artifact_gateway=llm_artifact_gateway, + trigger_repository=trigger_repository, + cron_job_gateway=cron_job_gateway, + monitoring_metrics_gateway=monitoring_metrics_gateway, + tokenizer_repository=tokenizer_repository, + streaming_storage_gateway=streaming_storage_gateway, + ) + return external_interfaces + + +def get_default_external_interfaces() -> ExternalInterfaces: + session = async_scoped_session(get_session_async(), scopefunc=asyncio.current_task) # type: ignore + return _get_external_interfaces(read_only=False, session=session) + + +def get_default_external_interfaces_read_only() -> ExternalInterfaces: + session = async_scoped_session( + get_session_read_only_async(), scopefunc=asyncio.current_task # type: ignore + ) + return _get_external_interfaces(read_only=True, session=session) + + +async def get_external_interfaces(): + try: + from plugins.dependencies import get_external_interfaces as get_custom_external_interfaces + + yield get_custom_external_interfaces() + except ModuleNotFoundError: + yield get_default_external_interfaces() + finally: + pass + + +async def get_external_interfaces_read_only(): + try: + from plugins.dependencies import ( + get_external_interfaces_read_only as get_custom_external_interfaces_read_only, + ) + + yield get_custom_external_interfaces_read_only() + except ModuleNotFoundError: + yield get_default_external_interfaces_read_only() + finally: + pass + + +def get_default_auth_repository() -> AuthenticationRepository: + auth_repo = FakeAuthenticationRepository() + return auth_repo + + +async def get_auth_repository(): + """ + Dependency for an AuthenticationRepository. This implementation returns a fake repository. + """ + try: + from plugins.dependencies import get_auth_repository as get_custom_auth_repository + + yield get_custom_auth_repository() + except ModuleNotFoundError: + yield get_default_auth_repository() + finally: + pass + + +async def verify_authentication( + credentials: Optional[HTTPBasicCredentials] = Depends(basic_auth), + tokens: Optional[str] = Depends(oauth2_scheme), + auth_repo: AuthenticationRepository = Depends(get_auth_repository), +) -> User: + """ + Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise, + raises a 401. + """ + # Basic Authentication + if credentials is not None: + username = credentials.username if credentials is not None else None + if username is None: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="No authentication was passed in", + headers={"WWW-Authenticate": "Basic"}, + ) + + auth = await auth_repo.get_auth_from_username_async(username=username) + + if not auth: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not authenticate user", + headers={"WWW-Authenticate": "Basic"}, + ) + + # set logger context with identity data + LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id) + LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id) + + return auth + + # bearer token + if tokens is not None: + auth = await auth_repo.get_auth_from_username_async(username=tokens) + if not auth: + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="Could not authenticate user", + headers={"WWW-Authenticate": "Bearer"}, + ) + + # set logger context with identity data + LoggerTagManager.set(LoggerTagKey.USER_ID, auth.user_id) + LoggerTagManager.set(LoggerTagKey.TEAM_ID, auth.team_id) + + return auth + + raise HTTPException( + status_code=status.HTTP_401_UNAUTHORIZED, + detail="No authentication was passed in", + headers={"WWW-Authenticate": "Bearer"}, + ) + + +_pool: Optional[aioredis.BlockingConnectionPool] = None + + +def get_or_create_aioredis_pool() -> aioredis.ConnectionPool: + global _pool + + expiration_timestamp = hmi_config.cache_redis_url_expiration_timestamp + if _pool is None or (expiration_timestamp is not None and time.time() > expiration_timestamp): + _pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) + return _pool diff --git a/server/llm_engine_server/api/docker_image_batch_job_bundles_v1.py b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py similarity index 78% rename from server/llm_engine_server/api/docker_image_batch_job_bundles_v1.py rename to model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py index c39ab0e9..be0b93ad 100644 --- a/server/llm_engine_server/api/docker_image_batch_job_bundles_v1.py +++ b/model-engine/model_engine_server/api/docker_image_batch_job_bundles_v1.py @@ -1,27 +1,26 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateDockerImageBatchJobBundleV1Request, CreateDockerImageBatchJobBundleV1Response, DockerImageBatchJobBundleV1Response, ListDockerImageBatchJobBundleV1Response, ) -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( + EndpointResourceInvalidRequestException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( +from model_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( CreateDockerImageBatchJobBundleV1UseCase, GetDockerImageBatchJobBundleByIdV1UseCase, GetLatestDockerImageBatchJobBundleByNameV1UseCase, @@ -30,12 +29,11 @@ docker_image_batch_job_bundle_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @docker_image_batch_job_bundle_router_v1.post( - "/docker-image-batch-job-bundles", - response_model=CreateDockerImageBatchJobBundleV1Response, + "/docker-image-batch-job-bundles", response_model=CreateDockerImageBatchJobBundleV1Response ) async def create_docker_image_batch_job_bundle( request: CreateDockerImageBatchJobBundleV1Request, @@ -45,7 +43,6 @@ async def create_docker_image_batch_job_bundle( """ Creates a docker iamge batch job bundle """ - add_trace_resource_name("docker_image_batch_job_bundle_post") logger.info(f"POST /docker-image-batch-job-bundles with {request} for {auth}") try: use_case = CreateDockerImageBatchJobBundleV1UseCase( @@ -60,8 +57,7 @@ async def create_docker_image_batch_job_bundle( @docker_image_batch_job_bundle_router_v1.get( - "/docker-image-batch-job-bundles", - response_model=ListDockerImageBatchJobBundleV1Response, + "/docker-image-batch-job-bundles", response_model=ListDockerImageBatchJobBundleV1Response ) async def list_docker_image_batch_job_model_bundles( bundle_name: Optional[str] = Query(default=None), @@ -73,7 +69,6 @@ async def list_docker_image_batch_job_model_bundles( Lists docker image batch job bundles owned by current owner """ - add_trace_resource_name("docker_image_batch_job_bundle_get") logger.info( f"GET /docker-image-batch-job-bundles?bundle_name={bundle_name}&order_by={order_by} for auth" ) @@ -84,8 +79,7 @@ async def list_docker_image_batch_job_model_bundles( @docker_image_batch_job_bundle_router_v1.get( - "/docker-image-batch-job-bundles/latest", - response_model=DockerImageBatchJobBundleV1Response, + "/docker-image-batch-job-bundles/latest", response_model=DockerImageBatchJobBundleV1Response ) async def get_latest_docker_image_batch_job_bundle( bundle_name: str, @@ -93,7 +87,6 @@ async def get_latest_docker_image_batch_job_bundle( external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> DockerImageBatchJobBundleV1Response: """Gets latest Docker Image Batch Job Bundle with given name owned by the current owner""" - add_trace_resource_name("docker_image_batch_job_bundle_latest_get") logger.info(f"GET /docker-image-batch-job-bundles/latest?bundle_name={bundle_name} for {auth}") try: use_case = GetLatestDockerImageBatchJobBundleByNameV1UseCase( @@ -117,7 +110,6 @@ async def get_docker_image_batch_job_model_bundle( external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> DockerImageBatchJobBundleV1Response: """Get details for a given DockerImageBatchJobBundle owned by the current owner""" - add_trace_resource_name("docker_image_batch_job_bundle_id_get") logger.info( f"GET /docker-image-batch-job-bundles/{docker_image_batch_job_bundle_id} for {auth}" ) diff --git a/model-engine/model_engine_server/api/files_v1.py b/model-engine/model_engine_server/api/files_v1.py new file mode 100644 index 00000000..d3c093f0 --- /dev/null +++ b/model-engine/model_engine_server/api/files_v1.py @@ -0,0 +1,121 @@ +"""Files API routes for the hosted model inference service.""" + +from fastapi import APIRouter, Depends, HTTPException, UploadFile +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.dtos.files import ( + DeleteFileResponse, + GetFileContentResponse, + GetFileResponse, + ListFilesResponse, + UploadFileResponse, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( + ObjectNotAuthorizedException, + ObjectNotFoundException, +) +from model_engine_server.domain.use_cases.file_use_cases import ( + DeleteFileUseCase, + GetFileContentUseCase, + GetFileUseCase, + ListFilesUseCase, + UploadFileUseCase, +) + +file_router_v1 = APIRouter(prefix="/v1") +logger = make_logger(logger_name()) + + +@file_router_v1.post("/files", response_model=UploadFileResponse) +async def upload_file( + file: UploadFile, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UploadFileResponse: + logger.info(f"POST /files with filename {file.filename} for {auth}") + use_case = UploadFileUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute( + user=auth, + filename=file.filename or "", + content=file.file.read(), + ) + + +@file_router_v1.get("/files/{file_id}", response_model=GetFileResponse) +async def get_file( + file_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> GetFileResponse: + logger.info(f"GET /files/{file_id} for {auth}") + try: + use_case = GetFileUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth, file_id=file_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified file could not be found.", + ) from exc + + +@file_router_v1.get("/files", response_model=ListFilesResponse) +async def list_files( + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> ListFilesResponse: + logger.info(f"GET /files for {auth}") + use_case = ListFilesUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth) + + +@file_router_v1.delete("/files/{file_id}", response_model=DeleteFileResponse) +async def delete_file( + file_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> DeleteFileResponse: + logger.info(f"DELETE /files/{file_id} for {auth}") + try: + use_case = DeleteFileUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth, file_id=file_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified file could not be found.", + ) from exc + + +@file_router_v1.get("/files/{file_id}/content", response_model=GetFileContentResponse) +async def get_file_content( + file_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> GetFileContentResponse: + """ + Describe the LLM Model endpoint with given name. + """ + logger.info(f"GET /files/{file_id}/content for {auth}") + try: + use_case = GetFileContentUseCase( + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth, file_id=file_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified file could not be found.", + ) from exc diff --git a/model-engine/model_engine_server/api/llms_v1.py b/model-engine/model_engine_server/api/llms_v1.py new file mode 100644 index 00000000..a52d81c6 --- /dev/null +++ b/model-engine/model_engine_server/api/llms_v1.py @@ -0,0 +1,655 @@ +"""LLM Model Endpoint routes for the hosted model inference service. +""" + +import traceback +from datetime import datetime +from typing import Optional + +import pytz +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.llms import ( + CancelFineTuneResponse, + CompletionStreamV1Request, + CompletionStreamV1Response, + CompletionSyncV1Request, + CompletionSyncV1Response, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1Response, + CreateFineTuneRequest, + CreateFineTuneResponse, + CreateLLMModelEndpointV1Request, + CreateLLMModelEndpointV1Response, + DeleteLLMEndpointResponse, + GetFineTuneEventsResponse, + GetFineTuneResponse, + GetLLMModelEndpointV1Response, + ListFineTunesResponse, + ListLLMModelEndpointsV1Response, + ModelDownloadRequest, + ModelDownloadResponse, + StreamError, + StreamErrorContent, + TokenUsage, + UpdateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Response, +) +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, + EndpointDeleteFailedException, + EndpointLabelsException, + EndpointResourceInvalidRequestException, + EndpointUnsupportedInferenceTypeException, + ExistingEndpointOperationInProgressException, + FailToInferHardwareException, + InvalidRequestException, + LLMFineTuningMethodNotImplementedException, + LLMFineTuningQuotaReached, + ObjectAlreadyExistsException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + UpstreamServiceError, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( + CancelFineTuneV1UseCase, + CreateFineTuneV1UseCase, + GetFineTuneEventsV1UseCase, + GetFineTuneV1UseCase, + ListFineTunesV1UseCase, +) +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + CompletionStreamV1UseCase, + CompletionSyncV1UseCase, + CreateBatchCompletionsUseCase, + CreateLLMModelBundleV1UseCase, + CreateLLMModelEndpointV1UseCase, + DeleteLLMEndpointByNameUseCase, + GetLLMModelEndpointByNameV1UseCase, + ListLLMModelEndpointsV1UseCase, + ModelDownloadV1UseCase, + UpdateLLMModelEndpointV1UseCase, +) +from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase +from sse_starlette.sse import EventSourceResponse + + +def format_request_route(request: Request) -> str: + url_path = request.url.path + for path_param in request.path_params: + url_path = url_path.replace(request.path_params[path_param], f":{path_param}") + return f"{request.method}_{url_path}".lower() + + +async def get_metric_metadata( + request: Request, + auth: User = Depends(verify_authentication), +) -> MetricMetadata: + model_name = request.query_params.get("model_endpoint_name", None) + return MetricMetadata(user=auth, model_name=model_name) + + +async def record_route_call( + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + route: str = Depends(format_request_route), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +): + external_interfaces.monitoring_metrics_gateway.emit_route_call_metric(route, metric_metadata) + + +llm_router_v1 = APIRouter(prefix="/v1/llm", dependencies=[Depends(record_route_call)]) +logger = make_logger(logger_name()) + + +def handle_streaming_exception( + e: Exception, + code: int, + message: str, +): + tb_str = traceback.format_exception(e) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": message, + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Exception: %s", structured_log) + return { + "data": CompletionStreamV1Response( + request_id=str(request_id), + error=StreamError( + status_code=code, + content=StreamErrorContent( + error=message, + timestamp=timestamp, + ), + ), + ).json() + } + + +@llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response) +async def create_model_endpoint( + request: CreateLLMModelEndpointV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CreateLLMModelEndpointV1Response: + """ + Creates an LLM endpoint for the current user. + """ + logger.info(f"POST /llm/model-endpoints with {request} for {auth}") + try: + create_model_bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=external_interfaces.model_bundle_repository, + docker_repository=external_interfaces.docker_repository, + model_primitive_gateway=external_interfaces.model_primitive_gateway, + ) + create_llm_model_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=create_model_bundle_use_case, + model_bundle_repository=external_interfaces.model_bundle_repository, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + docker_repository=external_interfaces.docker_repository, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, + model_endpoint_service=external_interfaces.model_endpoint_service, + docker_repository=external_interfaces.docker_repository, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + ) + return await use_case.execute(user=auth, request=request) + except ObjectAlreadyExistsException as exc: + raise HTTPException( + status_code=400, + detail="The specified model endpoint already exists.", + ) from exc + except EndpointLabelsException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except EndpointResourceInvalidRequestException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified model bundle could not be found.", + ) from exc + except DockerImageNotFoundException as exc: + raise HTTPException( + status_code=404, + detail="The specified docker image could not be found.", + ) from exc + except FailToInferHardwareException as exc: + raise HTTPException( + status_code=500, + detail="Failed to infer hardware exception.", + ) from exc + + +@llm_router_v1.get("/model-endpoints", response_model=ListLLMModelEndpointsV1Response) +async def list_model_endpoints( + name: Optional[str] = Query(default=None), + order_by: Optional[ModelEndpointOrderBy] = Query(default=None), + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> ListLLMModelEndpointsV1Response: + """ + Lists the LLM model endpoints owned by the current owner, plus all public_inference LLMs. + """ + logger.info(f"GET /llm/model-endpoints?name={name}&order_by={order_by} for {auth}") + use_case = ListLLMModelEndpointsV1UseCase( + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + ) + return await use_case.execute(user=auth, name=name, order_by=order_by) + + +@llm_router_v1.get( + "/model-endpoints/{model_endpoint_name}", + response_model=GetLLMModelEndpointV1Response, +) +async def get_model_endpoint( + model_endpoint_name: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> GetLLMModelEndpointV1Response: + """ + Describe the LLM Model endpoint with given name. + """ + logger.info(f"GET /llm/model-endpoints/{model_endpoint_name} for {auth}") + try: + use_case = GetLLMModelEndpointByNameV1UseCase( + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service + ) + return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + if isinstance(exc, ObjectNotAuthorizedException): # pragma: no cover + logger.info( + f"GET /llm/model-endpoints/{model_endpoint_name} for {auth} failed with authz error {exc.args}" + ) + + raise HTTPException( + status_code=404, + detail=f"Model Endpoint {model_endpoint_name} was not found.", + ) from exc + + +@llm_router_v1.put( + "/model-endpoints/{model_endpoint_name}", + response_model=UpdateLLMModelEndpointV1Response, +) +async def update_model_endpoint( + model_endpoint_name: str, + request: UpdateLLMModelEndpointV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UpdateLLMModelEndpointV1Response: + """ + Updates an LLM endpoint for the current user. + """ + logger.info(f"PUT /llm/model-endpoints/{model_endpoint_name} with {request} for {auth}") + try: + create_model_bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=external_interfaces.model_bundle_repository, + docker_repository=external_interfaces.docker_repository, + model_primitive_gateway=external_interfaces.model_primitive_gateway, + ) + create_llm_model_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=create_model_bundle_use_case, + model_bundle_repository=external_interfaces.model_bundle_repository, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + docker_repository=external_interfaces.docker_repository, + ) + use_case = UpdateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=create_llm_model_bundle_use_case, + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + docker_repository=external_interfaces.docker_repository, + ) + return await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + except EndpointLabelsException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except EndpointResourceInvalidRequestException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified LLM endpoint could not be found.", + ) from exc + except DockerImageNotFoundException as exc: + raise HTTPException( + status_code=404, + detail="The specified docker image could not be found.", + ) from exc + + +@llm_router_v1.post("/completions-sync", response_model=CompletionSyncV1Response) +async def create_completion_sync_task( + model_endpoint_name: str, + request: CompletionSyncV1Request, + background_tasks: BackgroundTasks, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +) -> CompletionSyncV1Response: + """ + Runs a sync prompt completion on an LLM. + """ + if hmi_config.sensitive_log_mode: # pragma: no cover + logger.info(f"POST /completions-sync to endpoint {model_endpoint_name} for {auth}") + else: + logger.info( + f"POST /completions-sync with {request} to endpoint {model_endpoint_name} for {auth}" + ) + try: + use_case = CompletionSyncV1UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + with timer() as use_case_timer: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=(response.output.num_prompt_tokens if response.output else None), + num_completion_tokens=( + response.output.num_completion_tokens if response.output else None + ), + total_duration=use_case_timer.duration, + ), + metric_metadata, + ) + return response + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + raise HTTPException( + status_code=500, + detail=f"Upstream service error for request_id {request_id}", + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + if isinstance(exc, ObjectNotAuthorizedException): # pragma: no cover + logger.info( + f"POST /completions-sync to endpoint {model_endpoint_name} for {auth} failed with authz error {exc.args}" + ) + + raise HTTPException( + status_code=404, + detail="The specified endpoint could not be found.", + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except InvalidRequestException as exc: + raise HTTPException(status_code=400, detail=str(exc)) + except EndpointUnsupportedInferenceTypeException as exc: + raise HTTPException( + status_code=400, + detail=f"Unsupported inference type: {str(exc)}", + ) from exc + + +@llm_router_v1.post("/completions-stream", response_model=CompletionStreamV1Response) +async def create_completion_stream_task( + model_endpoint_name: str, + request: CompletionStreamV1Request, + background_tasks: BackgroundTasks, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +) -> EventSourceResponse: + """ + Runs a stream prompt completion on an LLM. + """ + if hmi_config.sensitive_log_mode: # pragma: no cover + logger.info(f"POST /completions-stream to endpoint {model_endpoint_name} for {auth}") + else: + logger.info( + f"POST /completions-stream with {request} to endpoint {model_endpoint_name} for {auth}" + ) + use_case = CompletionStreamV1UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + + try: + # Call execute() with await, since it needs to handle exceptions before we begin streaming the response below. + # execute() will create a response chunk generator and return a reference to it. + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except EndpointUnsupportedInferenceTypeException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException( + status_code=500, + detail="Internal error occurred. Our team has been notified.", + ) from exc + + async def event_generator(): + try: + time_to_first_token = None + with timer() as use_case_timer: + async for message in response: + if time_to_first_token is None and message.output is not None: + time_to_first_token = use_case_timer.lap() + yield {"data": message.json()} + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=( + message.output.num_prompt_tokens if message.output else None + ), + num_completion_tokens=( + message.output.num_completion_tokens if message.output else None + ), + total_duration=use_case_timer.duration, + time_to_first_token=time_to_first_token, + ), + metric_metadata, + ) + # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object + except InvalidRequestException as exc: + yield handle_streaming_exception(exc, 400, str(exc)) + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + yield handle_streaming_exception( + exc, + 500, + f"Upstream service error for request_id {request_id}", + ) + except Exception as exc: + yield handle_streaming_exception( + exc, 500, "Internal error occurred. Our team has been notified." + ) + + return EventSourceResponse(event_generator()) + + +@llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneResponse) +async def create_fine_tune( + request: CreateFineTuneRequest, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CreateFineTuneResponse: + logger.info(f"POST /fine-tunes with {request} for {auth}") + try: + use_case = CreateFineTuneV1UseCase( + llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_fine_tune_events_repository=external_interfaces.llm_fine_tune_events_repository, + file_storage_gateway=external_interfaces.file_storage_gateway, + ) + return await use_case.execute(user=auth, request=request) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=str(exc)) from exc + except ( + LLMFineTuningMethodNotImplementedException, + LLMFineTuningQuotaReached, + InvalidRequestException, + ) as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + + +@llm_router_v1.get("/fine-tunes/{fine_tune_id}", response_model=GetFineTuneResponse) +async def get_fine_tune( + fine_tune_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> GetFineTuneResponse: + logger.info(f"GET /fine-tunes/{fine_tune_id} for {auth}") + try: + use_case = GetFineTuneV1UseCase( + llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, + ) + return await use_case.execute(user=auth, fine_tune_id=fine_tune_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified fine-tune job could not be found.", + ) from exc + + +@llm_router_v1.get("/fine-tunes", response_model=ListFineTunesResponse) +async def list_fine_tunes( + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> ListFineTunesResponse: + logger.info(f"GET /fine-tunes for {auth}") + use_case = ListFineTunesV1UseCase( + llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, + ) + return await use_case.execute(user=auth) + + +@llm_router_v1.put("/fine-tunes/{fine_tune_id}/cancel", response_model=CancelFineTuneResponse) +async def cancel_fine_tune( + fine_tune_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CancelFineTuneResponse: + logger.info(f"PUT /fine-tunes/{fine_tune_id}/cancel for {auth}") + try: + use_case = CancelFineTuneV1UseCase( + llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, + ) + return await use_case.execute(user=auth, fine_tune_id=fine_tune_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified fine-tune job could not be found.", + ) from exc + + +@llm_router_v1.get("/fine-tunes/{fine_tune_id}/events", response_model=GetFineTuneEventsResponse) +async def get_fine_tune_events( + fine_tune_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> GetFineTuneEventsResponse: + logger.info(f"GET /fine-tunes/{fine_tune_id}/events for {auth}") + try: + use_case = GetFineTuneEventsV1UseCase( + llm_fine_tune_events_repository=external_interfaces.llm_fine_tune_events_repository, + llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, + ) + return await use_case.execute(user=auth, fine_tune_id=fine_tune_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified fine-tune job's events could not be found.", + ) from exc + + +@llm_router_v1.post("/model-endpoints/download", response_model=ModelDownloadResponse) +async def download_model_endpoint( + request: ModelDownloadRequest, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> ModelDownloadResponse: + logger.info(f"POST /model-endpoints/download with {request} for {auth}") + try: + use_case = ModelDownloadV1UseCase( + filesystem_gateway=external_interfaces.filesystem_gateway, + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + ) + return await use_case.execute(user=auth, request=request) + except (ObjectNotFoundException, ObjectHasInvalidValueException) as exc: + raise HTTPException( + status_code=404, + detail="The requested fine-tuned model could not be found.", + ) from exc + + +@llm_router_v1.delete( + "/model-endpoints/{model_endpoint_name}", response_model=DeleteLLMEndpointResponse +) +async def delete_llm_model_endpoint( + model_endpoint_name: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> DeleteLLMEndpointResponse: + logger.info(f"DELETE /model-endpoints/{model_endpoint_name} for {auth}") + try: + use_case = DeleteLLMEndpointByNameUseCase( + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + model_endpoint_service=external_interfaces.model_endpoint_service, + ) + return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name) + except ObjectNotFoundException as exc: + raise HTTPException( + status_code=404, + detail="The requested model endpoint could not be found.", + ) from exc + except ObjectNotAuthorizedException as exc: + raise HTTPException( + status_code=403, + detail="You don't have permission to delete the requested model endpoint.", + ) from exc + except ExistingEndpointOperationInProgressException as exc: + raise HTTPException( + status_code=409, + detail="Existing operation on endpoint in progress, try again later.", + ) from exc + except EndpointDeleteFailedException as exc: # pragma: no cover + raise HTTPException( + status_code=500, + detail="deletion of endpoint failed.", + ) from exc + + +@llm_router_v1.post("/batch-completions", response_model=CreateBatchCompletionsV1Response) +async def create_batch_completions( + request: CreateBatchCompletionsV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CreateBatchCompletionsV1Response: + logger.info(f"POST /batch-completions with {request} for {auth}") + try: + use_case = CreateBatchCompletionsUseCase( + docker_image_batch_job_gateway=external_interfaces.docker_image_batch_job_gateway, + docker_repository=external_interfaces.docker_repository, + docker_image_batch_job_bundle_repo=external_interfaces.docker_image_batch_job_bundle_repository, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + ) + return await use_case.execute(user=auth, request=request) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail="The specified endpoint could not be found.", + ) from exc + except (InvalidRequestException, ObjectHasInvalidValueException) as exc: + raise HTTPException(status_code=400, detail=str(exc)) diff --git a/server/llm_engine_server/api/model_bundles_v1.py b/model-engine/model_engine_server/api/model_bundles_v1.py similarity index 87% rename from server/llm_engine_server/api/model_bundles_v1.py rename to model-engine/model_engine_server/api/model_bundles_v1.py index efcf43ab..de73fb4c 100644 --- a/server/llm_engine_server/api/model_bundles_v1.py +++ b/model-engine/model_engine_server/api/model_bundles_v1.py @@ -1,16 +1,15 @@ -"""Model Bundle v1 routes for the LLMEngine service.""" +"""Model Bundle v1 routes for the hosted model inference service.""" from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV1Request, CreateModelBundleV1Request, CreateModelBundleV1Response, @@ -18,15 +17,15 @@ ModelBundleOrderBy, ModelBundleV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.use_cases.model_bundle_use_cases import ( +from model_engine_server.domain.use_cases.model_bundle_use_cases import ( CloneModelBundleV1UseCase, CreateModelBundleV1UseCase, GetLatestModelBundleByNameV1UseCase, @@ -35,7 +34,7 @@ ) model_bundle_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @model_bundle_router_v1.post("/model-bundles", response_model=CreateModelBundleV1Response) @@ -48,7 +47,6 @@ async def create_model_bundle( Creates a ModelBundle for the current user. """ logger.info(f"POST /model-bundles with {request} for {auth}") - add_trace_resource_name("model_bundles_post") try: use_case = CreateModelBundleV1UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -84,7 +82,6 @@ async def clone_model_bundle_with_changes( """ Creates a ModelBundle by cloning an existing one and then applying changes on top. """ - add_trace_resource_name("model_bundles_clone") try: use_case = CloneModelBundleV1UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -107,7 +104,6 @@ async def list_model_bundles( """ Lists the ModelBundles owned by the current owner. """ - add_trace_resource_name("model_bundles_get") logger.info(f"GET /model-bundles?model_name={model_name}&order_by={order_by} for {auth}") use_case = ListModelBundlesV1UseCase( model_bundle_repository=external_interfaces.model_bundle_repository @@ -124,7 +120,6 @@ async def get_latest_model_bundle( """ Gets the latest Model Bundle with the given name owned by the current owner. """ - add_trace_resource_name("model_bundles_latest_get") logger.info(f"GET /model-bundles/latest?model_name={model_name} for {auth}") try: use_case = GetLatestModelBundleByNameV1UseCase( @@ -149,7 +144,6 @@ async def get_model_bundle( """ Gets the details for a given ModelBundle owned by the current owner. """ - add_trace_resource_name("model_bundles_id_get") logger.info(f"GET /model-bundles/{model_bundle_id} for {auth}") try: use_case = GetModelBundleByIdV1UseCase( diff --git a/server/llm_engine_server/api/model_bundles_v2.py b/model-engine/model_engine_server/api/model_bundles_v2.py similarity index 88% rename from server/llm_engine_server/api/model_bundles_v2.py rename to model-engine/model_engine_server/api/model_bundles_v2.py index 00d4ffed..3376de70 100644 --- a/server/llm_engine_server/api/model_bundles_v2.py +++ b/model-engine/model_engine_server/api/model_bundles_v2.py @@ -3,14 +3,13 @@ from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV2Request, CreateModelBundleV2Request, CreateModelBundleV2Response, @@ -18,15 +17,15 @@ ModelBundleOrderBy, ModelBundleV2Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.use_cases.model_bundle_use_cases import ( +from model_engine_server.domain.use_cases.model_bundle_use_cases import ( CloneModelBundleV2UseCase, CreateModelBundleV2UseCase, GetLatestModelBundleByNameV2UseCase, @@ -35,7 +34,7 @@ ) model_bundle_router_v2 = APIRouter(prefix="/v2") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @model_bundle_router_v2.post("/model-bundles", response_model=CreateModelBundleV2Response) @@ -48,7 +47,6 @@ async def create_model_bundle( Creates a ModelBundle for the current user. """ logger.info(f"POST /model-bundles with {request} for {auth}") - add_trace_resource_name("model_bundles_post") try: use_case = CreateModelBundleV2UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -84,7 +82,6 @@ async def clone_model_bundle_with_changes( """ Creates a ModelBundle by cloning an existing one and then applying changes on top. """ - add_trace_resource_name("model_bundles_clone") try: use_case = CloneModelBundleV2UseCase( model_bundle_repository=external_interfaces.model_bundle_repository, @@ -107,7 +104,6 @@ async def list_model_bundles( """ Lists the ModelBundles owned by the current owner. """ - add_trace_resource_name("model_bundles_get") logger.info(f"GET /model-bundles?model_name={model_name}&order_by={order_by} for {auth}") use_case = ListModelBundlesV2UseCase( model_bundle_repository=external_interfaces.model_bundle_repository @@ -124,7 +120,6 @@ async def get_latest_model_bundle( """ Gets the latest Model Bundle with the given name owned by the current owner. """ - add_trace_resource_name("model_bundles_latest_get") logger.info(f"GET /model-bundles/latest?model_name={model_name} for {auth}") try: use_case = GetLatestModelBundleByNameV2UseCase( @@ -149,7 +144,6 @@ async def get_model_bundle( """ Gets the details for a given ModelBundle owned by the current owner. """ - add_trace_resource_name("model_bundles_id_get") logger.info(f"GET /model-bundles/{model_bundle_id} for {auth}") try: use_case = GetModelBundleByIdV2UseCase( diff --git a/server/llm_engine_server/api/model_endpoints_docs_v1.py b/model-engine/model_engine_server/api/model_endpoints_docs_v1.py similarity index 82% rename from server/llm_engine_server/api/model_endpoints_docs_v1.py rename to model-engine/model_engine_server/api/model_endpoints_docs_v1.py index 5ccb8a30..f4f2d734 100644 --- a/server/llm_engine_server/api/model_endpoints_docs_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_docs_v1.py @@ -2,20 +2,20 @@ from fastapi.encoders import jsonable_encoder from fastapi.openapi.docs import get_redoc_html from fastapi.responses import JSONResponse -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.use_cases.model_endpoints_schema_use_cases import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.use_cases.model_endpoints_schema_use_cases import ( GetModelEndpointsSchemaV1UseCase, ) from starlette.responses import HTMLResponse model_endpoints_docs_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @model_endpoints_docs_router_v1.get("/model-endpoints-schema.json") diff --git a/server/llm_engine_server/api/model_endpoints_v1.py b/model-engine/model_engine_server/api/model_endpoints_v1.py similarity index 81% rename from server/llm_engine_server/api/model_endpoints_v1.py rename to model-engine/model_engine_server/api/model_endpoints_v1.py index a1e28df3..fd3a06a4 100644 --- a/server/llm_engine_server/api/model_endpoints_v1.py +++ b/model-engine/model_engine_server/api/model_endpoints_v1.py @@ -3,17 +3,17 @@ List model endpoint history: GET model-endpoints//history Read model endpoint creation logs: GET model-endpoints//creation-logs """ + from typing import Optional from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, DeleteModelEndpointV1Response, @@ -23,22 +23,21 @@ UpdateModelEndpointV1Request, UpdateModelEndpointV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( - ObjectAlreadyExistsException, - ObjectHasInvalidValueException, - ObjectNotApprovedException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, + EndpointInfraStateNotFound, EndpointLabelsException, EndpointResourceInvalidRequestException, ExistingEndpointOperationInProgressException, + ObjectAlreadyExistsException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + PostInferenceHooksException, ) -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CreateModelEndpointV1UseCase, DeleteModelEndpointByIdV1UseCase, GetModelEndpointByIdV1UseCase, @@ -47,7 +46,7 @@ ) model_endpoint_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @model_endpoint_router_v1.post("/model-endpoints", response_model=CreateModelEndpointV1Response) @@ -59,7 +58,6 @@ async def create_model_endpoint( """ Creates a Model for the current user. """ - add_trace_resource_name("model_endpoints_post") logger.info(f"POST /model-endpoints with {request} for {auth}") try: use_case = CreateModelEndpointV1UseCase( @@ -72,23 +70,15 @@ async def create_model_endpoint( status_code=400, detail="The specified model endpoint already exists.", ) from exc - except EndpointLabelsException as exc: - raise HTTPException( - status_code=400, - detail=str(exc), - ) from exc - except ObjectHasInvalidValueException as exc: - raise HTTPException(status_code=400, detail=str(exc)) - except EndpointResourceInvalidRequestException as exc: + except ( + EndpointLabelsException, + ObjectHasInvalidValueException, + EndpointResourceInvalidRequestException, + ) as exc: raise HTTPException( status_code=400, detail=str(exc), ) from exc - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: raise HTTPException( status_code=404, @@ -106,7 +96,6 @@ async def list_model_endpoints( """ Lists the Models owned by the current owner. """ - add_trace_resource_name("model_endpoints_get") logger.info(f"GET /model-endpoints?name={name}&order_by={order_by} for {auth}") use_case = ListModelEndpointsV1UseCase( model_endpoint_service=external_interfaces.model_endpoint_service, @@ -125,7 +114,6 @@ async def get_model_endpoint( """ Describe the Model endpoint with given ID. """ - add_trace_resource_name("model_endpoints_id_get") logger.info(f"GET /model-endpoints/{model_endpoint_id} for {auth}") try: use_case = GetModelEndpointByIdV1UseCase( @@ -149,9 +137,8 @@ async def update_model_endpoint( external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), ) -> UpdateModelEndpointV1Response: """ - Lists the Models owned by the current owner. + Updates the Model endpoint. """ - add_trace_resource_name("model_endpoints_id_put") logger.info(f"PUT /model-endpoints/{model_endpoint_id} with {request} for {auth}") try: use_case = UpdateModelEndpointByIdV1UseCase( @@ -161,12 +148,12 @@ async def update_model_endpoint( return await use_case.execute( user=auth, model_endpoint_id=model_endpoint_id, request=request ) - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc - except EndpointLabelsException as exc: + except ( + EndpointLabelsException, + ObjectHasInvalidValueException, + EndpointResourceInvalidRequestException, + PostInferenceHooksException, + ) as exc: raise HTTPException( status_code=400, detail=str(exc), @@ -181,6 +168,11 @@ async def update_model_endpoint( status_code=409, detail="Existing operation on endpoint in progress, try again later.", ) from exc + except EndpointInfraStateNotFound as exc: + raise HTTPException( + status_code=500, + detail="Endpoint infra state not found, try again later.", + ) from exc @model_endpoint_router_v1.delete( @@ -194,7 +186,6 @@ async def delete_model_endpoint( """ Lists the Models owned by the current owner. """ - add_trace_resource_name("model_endpoints_id_delete") logger.info(f"DELETE /model-endpoints/{model_endpoint_id} for {auth}") try: use_case = DeleteModelEndpointByIdV1UseCase( diff --git a/server/llm_engine_server/api/tasks_v1.py b/model-engine/model_engine_server/api/tasks_v1.py similarity index 84% rename from server/llm_engine_server/api/tasks_v1.py rename to model-engine/model_engine_server/api/tasks_v1.py index e0318d94..663b3e0c 100644 --- a/server/llm_engine_server/api/tasks_v1.py +++ b/model-engine/model_engine_server/api/tasks_v1.py @@ -1,41 +1,42 @@ +import asyncio + from fastapi import APIRouter, Depends, HTTPException -from llm_engine_server.api.dependencies import ( +from model_engine_server.api.dependencies import ( ExternalInterfaces, get_external_interfaces_read_only, verify_authentication, ) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, TaskStatus, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + InvalidRequestException, ObjectNotAuthorizedException, ObjectNotFoundException, -) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import ( - EndpointUnsupportedInferenceTypeException, UpstreamServiceError, ) -from llm_engine_server.domain.use_cases.async_inference_use_cases import ( +from model_engine_server.domain.use_cases.async_inference_use_cases import ( CreateAsyncInferenceTaskV1UseCase, GetAsyncInferenceTaskV1UseCase, ) -from llm_engine_server.domain.use_cases.streaming_inference_use_cases import ( +from model_engine_server.domain.use_cases.streaming_inference_use_cases import ( CreateStreamingInferenceTaskV1UseCase, ) -from llm_engine_server.domain.use_cases.sync_inference_use_cases import ( +from model_engine_server.domain.use_cases.sync_inference_use_cases import ( CreateSyncInferenceTaskV1UseCase, ) from sse_starlette.sse import EventSourceResponse inference_task_router_v1 = APIRouter(prefix="/v1") -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) @inference_task_router_v1.post("/async-tasks", response_model=CreateAsyncTaskV1Response) @@ -48,7 +49,6 @@ async def create_async_inference_task( """ Runs an async inference prediction. """ - add_trace_resource_name("task_async_post") logger.info(f"POST /async-tasks {request} to endpoint {model_endpoint_id} for {auth}") try: use_case = CreateAsyncInferenceTaskV1UseCase( @@ -67,6 +67,11 @@ async def create_async_inference_task( status_code=400, detail=f"Unsupported inference type: {str(exc)}", ) from exc + except InvalidRequestException as exc: + raise HTTPException( + status_code=400, + detail=f"Invalid request: {str(exc)}", + ) from exc @inference_task_router_v1.get("/async-tasks/{task_id}", response_model=GetAsyncTaskV1Response) @@ -78,7 +83,6 @@ def get_async_inference_task( """ Gets the status of an async inference task. """ - add_trace_resource_name("task_async_id_get") logger.info(f"GET /async-tasks/{task_id} for {auth}") try: use_case = GetAsyncInferenceTaskV1UseCase( @@ -95,14 +99,13 @@ def get_async_inference_task( @inference_task_router_v1.post("/sync-tasks", response_model=SyncEndpointPredictV1Response) async def create_sync_inference_task( model_endpoint_id: str, - request: EndpointPredictV1Request, + request: SyncEndpointPredictV1Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> SyncEndpointPredictV1Response: """ Runs a sync inference prediction. """ - add_trace_resource_name("task_sync_post") logger.info(f"POST /sync-tasks with {request} to endpoint {model_endpoint_id} for {auth}") try: use_case = CreateSyncInferenceTaskV1UseCase( @@ -125,19 +128,23 @@ async def create_sync_inference_task( status_code=400, detail=f"Unsupported inference type: {str(exc)}", ) from exc + except asyncio.exceptions.TimeoutError as exc: + raise HTTPException( + status_code=408, + detail="Request timed out.", + ) from exc @inference_task_router_v1.post("/streaming-tasks") async def create_streaming_inference_task( model_endpoint_id: str, - request: EndpointPredictV1Request, + request: SyncEndpointPredictV1Request, auth: User = Depends(verify_authentication), external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), ) -> EventSourceResponse: """ Runs a streaming inference prediction. """ - add_trace_resource_name("task_streaming_post") logger.info(f"POST /streaming-tasks with {request} to endpoint {model_endpoint_id} for {auth}") try: use_case = CreateStreamingInferenceTaskV1UseCase( diff --git a/model-engine/model_engine_server/api/triggers_v1.py b/model-engine/model_engine_server/api/triggers_v1.py new file mode 100644 index 00000000..010140af --- /dev/null +++ b/model-engine/model_engine_server/api/triggers_v1.py @@ -0,0 +1,168 @@ +from fastapi import APIRouter, Depends, HTTPException +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces, + verify_authentication, +) +from model_engine_server.common.dtos.triggers import ( + CreateTriggerV1Request, + CreateTriggerV1Response, + DeleteTriggerV1Response, + GetTriggerV1Response, + ListTriggersV1Response, + UpdateTriggerV1Request, + UpdateTriggerV1Response, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( + CronSyntaxException, + DockerImageNotFoundException, + EndpointLabelsException, + EndpointResourceInvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + TriggerNameAlreadyExistsException, +) +from model_engine_server.domain.use_cases.trigger_use_cases import ( + CreateTriggerUseCase, + DeleteTriggerUseCase, + GetTriggerUseCase, + ListTriggersUseCase, + UpdateTriggerUseCase, +) + +trigger_router_v1 = APIRouter(prefix="/v1") + +logger = make_logger(logger_name()) + + +@trigger_router_v1.post("/triggers", response_model=CreateTriggerV1Response) +async def create_trigger( + request: CreateTriggerV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CreateTriggerV1Response: + """ + Creates and runs a trigger + """ + logger.info(f"POST /triggers with {request} for {auth}") + try: + use_case = CreateTriggerUseCase( + trigger_repository=external_interfaces.trigger_repository, + cron_job_gateway=external_interfaces.cron_job_gateway, + docker_image_batch_job_bundle_repository=external_interfaces.docker_image_batch_job_bundle_repository, + docker_repository=external_interfaces.docker_repository, + ) + return await use_case.execute(user=auth, request=request) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, detail="The specified batch job bundle could not be found" + ) from exc + except DockerImageNotFoundException as exc: + raise HTTPException( + status_code=404, + detail=f"The specified docker image {exc.repository}:{exc.tag} was not found", + ) + except ObjectHasInvalidValueException as exc: + raise HTTPException( + status_code=400, + detail=f"The user specified an invalid value: {exc}", + ) from exc + except EndpointResourceInvalidRequestException as exc: + raise HTTPException( + status_code=400, + detail=f"Default trigger resource request is invalid: {exc}", + ) + except EndpointLabelsException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except CronSyntaxException as exc: + raise HTTPException( + status_code=400, + detail=f"The user specified an invalid value for cron_schedule: {exc}", + ) + except TriggerNameAlreadyExistsException as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + + +@trigger_router_v1.get("/triggers", response_model=ListTriggersV1Response) +async def list_triggers( + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> ListTriggersV1Response: + """ + Lists descriptions of all triggers + """ + logger.info(f"GET /triggers for {auth}") + use_case = ListTriggersUseCase(trigger_repository=external_interfaces.trigger_repository) + return await use_case.execute(user=auth) + + +@trigger_router_v1.get("/triggers/{trigger_id}", response_model=GetTriggerV1Response) +async def get_trigger( + trigger_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> GetTriggerV1Response: + """ + Describes the trigger with the given ID + """ + logger.info(f"GET /triggers/{trigger_id} for {auth}") + try: + use_case = GetTriggerUseCase(trigger_repository=external_interfaces.trigger_repository) + return await use_case.execute(user=auth, trigger_id=trigger_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=f"Trigger {trigger_id} was not found.") from exc + + +@trigger_router_v1.put("/triggers/{trigger_id}", response_model=UpdateTriggerV1Response) +async def update_trigger( + trigger_id: str, + request: UpdateTriggerV1Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UpdateTriggerV1Response: + """ + Updates the trigger with the given ID + """ + logger.info(f"PUT /triggers/{trigger_id} with {request} for {auth}") + try: + use_case = UpdateTriggerUseCase( + trigger_repository=external_interfaces.trigger_repository, + cron_job_gateway=external_interfaces.cron_job_gateway, + ) + return await use_case.execute(user=auth, trigger_id=trigger_id, request=request) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=f"Trigger {trigger_id} was not found.") from exc + except CronSyntaxException as exc: + raise HTTPException( + status_code=400, + detail=f"The user specified an invalid value for cron_schedule: {exc}", + ) + + +@trigger_router_v1.delete("/triggers/{trigger_id}", response_model=DeleteTriggerV1Response) +async def delete_trigger( + trigger_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> DeleteTriggerV1Response: + """ + Deletes the trigger with the given ID + """ + logger.info(f"DELETE /triggers/{trigger_id} for {auth}") + try: + use_case = DeleteTriggerUseCase( + trigger_repository=external_interfaces.trigger_repository, + cron_job_gateway=external_interfaces.cron_job_gateway, + ) + return await use_case.execute(user=auth, trigger_id=trigger_id) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException(status_code=404, detail=f"Trigger {trigger_id} was not found.") from exc diff --git a/model-engine/model_engine_server/api/v2/__init__.py b/model-engine/model_engine_server/api/v2/__init__.py new file mode 100644 index 00000000..abb0fdec --- /dev/null +++ b/model-engine/model_engine_server/api/v2/__init__.py @@ -0,0 +1,14 @@ +from typing import Sequence + +from fastapi import APIRouter + +from .batch_completion import batch_completions_router_v2 +from .chat_completion import chat_router_v2 +from .completion import completion_router_v2 + +llm_router_v2 = APIRouter(prefix="/v2") +llm_router_v2.include_router(batch_completions_router_v2) +llm_router_v2.include_router(chat_router_v2) +llm_router_v2.include_router(completion_router_v2) + +__all__: Sequence[str] = ("llm_router_v2",) diff --git a/model-engine/model_engine_server/api/v2/batch_completion.py b/model-engine/model_engine_server/api/v2/batch_completion.py new file mode 100644 index 00000000..78a8bfdf --- /dev/null +++ b/model-engine/model_engine_server/api/v2/batch_completion.py @@ -0,0 +1,160 @@ +from fastapi import APIRouter, Depends, HTTPException +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.dtos.llms.batch_completion import ( + CancelBatchCompletionsV2Response, + CreateBatchCompletionsV2Request, + CreateBatchCompletionsV2Response, + GetBatchCompletionV2Response, + UpdateBatchCompletionsV2Request, + UpdateBatchCompletionsV2Response, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.domain.exceptions import ( + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + CancelBatchCompletionV2UseCase, + CreateBatchCompletionsV2UseCase, + GetBatchCompletionV2UseCase, + UpdateBatchCompletionV2UseCase, +) + +from .common import get_metric_metadata, record_route_call + +logger = make_logger(logger_name()) + + +batch_completions_router_v2 = APIRouter( + prefix="/batch-completions", dependencies=[Depends(record_route_call)] +) + + +@batch_completions_router_v2.post("", response_model=CreateBatchCompletionsV2Response) +async def batch_completions( + request: CreateBatchCompletionsV2Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), +) -> CreateBatchCompletionsV2Response: # pragma: no cover + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.info(f"POST /v2/batch-completions {request} for {auth}") + try: + use_case = CreateBatchCompletionsV2UseCase( + llm_batch_completions_service=external_interfaces.llm_batch_completions_service, + llm_artifact_gateway=external_interfaces.llm_artifact_gateway, + ) + + return await use_case.execute(request, user=auth) + except ObjectHasInvalidValueException as exc: # pragma: no cover + raise HTTPException(status_code=400, detail=str(exc)) + except ObjectNotFoundException as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + + except Exception as exc: + logger.exception(f"Error processing request {request} for {auth}") + raise HTTPException( + status_code=500, + detail=f"Internal server error. request_id: {request_id}", + ) from exc + + +@batch_completions_router_v2.get( + "/{batch_completion_id}", + response_model=GetBatchCompletionV2Response, +) +async def get_batch_completion( + batch_completion_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +) -> GetBatchCompletionV2Response: + logger.info(f"GET /v2/batch-completions/{batch_completion_id} for {auth}") + try: + use_case = GetBatchCompletionV2UseCase( + llm_batch_completions_service=external_interfaces.llm_batch_completions_service, + ) + return await use_case.execute(batch_completion_id=batch_completion_id, user=auth) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + + +@batch_completions_router_v2.post( + "/{batch_completion_id}", + response_model=UpdateBatchCompletionsV2Response, +) +async def update_batch_completion( + batch_completion_id: str, + request: UpdateBatchCompletionsV2Request, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> UpdateBatchCompletionsV2Response: # pragma: no cover + logger.info(f"POST /v2/batch-completions/{batch_completion_id} {request} for {auth}") + try: + use_case = UpdateBatchCompletionV2UseCase( + llm_batch_completions_service=external_interfaces.llm_batch_completions_service, + ) + return await use_case.execute( + batch_completion_id=batch_completion_id, request=request, user=auth + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except Exception as exc: + logger.exception(f"Error processing request {request} for {auth}", exc_info=exc) + raise HTTPException( + status_code=500, + detail="Internal server error", + ) from exc + + +@batch_completions_router_v2.post( + "/{batch_completion_id}/actions/cancel", + response_model=CancelBatchCompletionsV2Response, +) +async def cancel_batch_completion( + batch_completion_id: str, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), +) -> CancelBatchCompletionsV2Response: # pragma: no cover + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.info(f"POST /v2/batch-completions/{batch_completion_id}/actions/cancel for {auth}") + try: + use_case = CancelBatchCompletionV2UseCase( + llm_batch_completions_service=external_interfaces.llm_batch_completions_service, + ) + return await use_case.execute(batch_completion_id=batch_completion_id, user=auth) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except Exception as exc: + logger.exception( + f"Error canceling batch completions {batch_completion_id} for {auth}", + exc_info=exc, + ) + raise HTTPException( + status_code=500, + detail=f"Internal server error. request_id: {request_id}", + ) from exc diff --git a/model-engine/model_engine_server/api/v2/chat_completion.py b/model-engine/model_engine_server/api/v2/chat_completion.py new file mode 100644 index 00000000..614f159d --- /dev/null +++ b/model-engine/model_engine_server/api/v2/chat_completion.py @@ -0,0 +1,286 @@ +import traceback +from datetime import datetime +from typing import Any + +import pytz +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.llms import ( + ChatCompletionV2Request, + ChatCompletionV2Response, + ChatCompletionV2ResponseItem, + ChatCompletionV2StreamErrorChunk, + StreamError, + StreamErrorContent, + TokenUsage, +) +from model_engine_server.common.dtos.llms.chat_completion import ChatCompletionV2StreamSuccessChunk +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + InvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + UpstreamServiceError, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + ChatCompletionStreamV2UseCase, + ChatCompletionSyncV2UseCase, +) +from sse_starlette import EventSourceResponse + +from .common import get_metric_metadata, record_route_call + +logger = make_logger(logger_name()) + +chat_router_v2 = APIRouter(dependencies=[Depends(record_route_call)]) + + +def handle_streaming_exception( + e: Exception, + code: int, + message: str, +): # pragma: no cover + tb_str = traceback.format_exception(e) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": message, + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Exception: %s", structured_log) + return { + "data": ChatCompletionV2StreamErrorChunk( + request_id=str(request_id), + error=StreamError( + status_code=code, + content=StreamErrorContent( + error=message, + timestamp=timestamp, + ), + ), + ).model_dump_json(exclude_none=True) + } + + +async def handle_stream_request( + external_interfaces: ExternalInterfaces, + background_tasks: BackgroundTasks, + request: ChatCompletionV2Request, + auth: User, + model_endpoint_name: str, + metric_metadata: MetricMetadata, +): # pragma: no cover + use_case = ChatCompletionStreamV2UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + + with timer() as use_case_timer: + try: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + + # We fetch the first response to check if upstream request was successful + # If it was not, this will raise the corresponding HTTPException + # If it was, we will proceed to the event generator + first_message: ChatCompletionV2StreamSuccessChunk = await response.__anext__() + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + ) as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException( + status_code=500, + detail="Internal error occurred. Our team has been notified.", + ) from exc + + async def event_generator(timer: timer = use_case_timer): + try: + ttft = None + message = None + yield {"data": first_message.model_dump_json(exclude_none=True)} + + async for message in response: + if ttft is None: + ttft = timer.lap() + # if ttft is None and message.startswith("data"): + # ttft = timer.lap() + yield {"data": message.model_dump_json(exclude_none=True)} + + if message: + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=( + message.usage.prompt_tokens if message.usage else None + ), + num_completion_tokens=( + message.usage.completion_tokens if message.usage else None + ), + total_duration=timer.duration, + ), + metric_metadata, + ) + + # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object + except InvalidRequestException as exc: + yield handle_streaming_exception(exc, 400, str(exc)) + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + yield handle_streaming_exception( + exc, + 500, + f"Upstream service error for request_id {request_id}", + ) + except Exception as exc: + yield handle_streaming_exception( + exc, 500, "Internal error occurred. Our team has been notified." + ) + + return EventSourceResponse(event_generator(timer=use_case_timer)) + + +async def handle_sync_request( + external_interfaces: ExternalInterfaces, + request: ChatCompletionV2Request, + background_tasks: BackgroundTasks, + auth: User, + model_endpoint_name: str, + metric_metadata: MetricMetadata, +): + try: + use_case = ChatCompletionSyncV2UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + with timer() as use_case_timer: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=(response.usage.prompt_tokens if response.usage else None), + num_completion_tokens=( + response.usage.completion_tokens if response.usage else None + ), + total_duration=use_case_timer.duration, + ), + metric_metadata, + ) + return response + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + raise HTTPException( + status_code=500, + detail=f"Upstream service error for request_id {request_id}", + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + if isinstance(exc, ObjectNotAuthorizedException): # pragma: no cover + logger.info( + f"POST /completions-sync to endpoint {model_endpoint_name} for {auth} failed with authz error {exc.args}" + ) + + raise HTTPException( + status_code=404, + detail="The specified endpoint could not be found.", + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=to_error_details(exc)) + except InvalidRequestException as exc: + raise HTTPException(status_code=400, detail=to_error_details(exc)) + except EndpointUnsupportedRequestException as exc: + raise HTTPException( + status_code=400, + detail=f"Endpoint does not support request: {str(exc)}", + ) from exc + except EndpointUnsupportedInferenceTypeException as exc: + raise HTTPException( + status_code=400, + detail=f"Unsupported inference type: {str(exc)}", + ) from exc + + +def to_error_details(exc: Exception) -> Any: + if not exc.args or len(exc.args) == 0: + return str(exc) + if len(exc.args) == 1: + return exc.args[0] + else: + return exc.args + + +@chat_router_v2.post("/chat/completions", response_model=ChatCompletionV2ResponseItem) +async def chat_completion( + request: ChatCompletionV2Request, + background_tasks: BackgroundTasks, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +) -> ChatCompletionV2Response: # pragma: no cover + model_endpoint_name = request.model + if hmi_config.sensitive_log_mode: + logger.info( + f"POST /v2/chat/completion ({('stream' if request.stream else 'sync')}) to endpoint {model_endpoint_name} for {auth}" + ) + else: + logger.info( + f"POST /v2/chat/completion ({('stream' if request.stream else 'sync')}) with {request} to endpoint {model_endpoint_name} for {auth}" + ) + + if request.stream: + return await handle_stream_request( + external_interfaces=external_interfaces, + background_tasks=background_tasks, + request=request, + auth=auth, + model_endpoint_name=model_endpoint_name, + metric_metadata=metric_metadata, + ) + else: + return await handle_sync_request( + external_interfaces=external_interfaces, + background_tasks=background_tasks, + request=request, + auth=auth, + model_endpoint_name=model_endpoint_name, + metric_metadata=metric_metadata, + ) diff --git a/model-engine/model_engine_server/api/v2/common.py b/model-engine/model_engine_server/api/v2/common.py new file mode 100644 index 00000000..d651eb4b --- /dev/null +++ b/model-engine/model_engine_server/api/v2/common.py @@ -0,0 +1,37 @@ +from fastapi import Depends, Request +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata + + +def format_request_route(request: Request) -> str: + url_path = request.url.path + for path_param in request.path_params: + url_path = url_path.replace(request.path_params[path_param], f":{path_param}") + return f"{request.method}_{url_path}".lower() + + +async def get_metric_metadata( + request: Request, + auth: User = Depends(verify_authentication), +) -> MetricMetadata: + # note that this is ok because request will cache the body + body = await request.json() + model_name = body.get("model", None) + if not model_name: + # get model name from batch completion request + model_name = body.get("model_config", {}).get("model", None) + + return MetricMetadata(user=auth, model_name=model_name) + + +async def record_route_call( + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + route: str = Depends(format_request_route), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +): + external_interfaces.monitoring_metrics_gateway.emit_route_call_metric(route, metric_metadata) diff --git a/model-engine/model_engine_server/api/v2/completion.py b/model-engine/model_engine_server/api/v2/completion.py new file mode 100644 index 00000000..ed529fe3 --- /dev/null +++ b/model-engine/model_engine_server/api/v2/completion.py @@ -0,0 +1,285 @@ +import traceback +from datetime import datetime +from typing import Any + +import pytz +from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException +from model_engine_server.api.dependencies import ( + ExternalInterfaces, + get_external_interfaces_read_only, + verify_authentication, +) +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.llms import ( + CompletionV2Request, + CompletionV2Response, + CompletionV2ResponseItem, + CompletionV2StreamErrorChunk, + StreamError, + StreamErrorContent, + TokenUsage, +) +from model_engine_server.common.dtos.llms.completion import CompletionV2StreamSuccessChunk +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + InvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + UpstreamServiceError, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + CompletionStreamV2UseCase, + CompletionSyncV2UseCase, +) +from sse_starlette import EventSourceResponse + +from .common import get_metric_metadata, record_route_call + +logger = make_logger(logger_name()) + +completion_router_v2 = APIRouter(dependencies=[Depends(record_route_call)]) + + +def handle_streaming_exception( + e: Exception, + code: int, + message: str, +): # pragma: no cover + tb_str = traceback.format_exception(e) + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + timestamp = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %H:%M:%S %Z") + structured_log = { + "error": message, + "request_id": str(request_id), + "traceback": "".join(tb_str), + } + logger.error("Exception: %s", structured_log) + return { + "data": CompletionV2StreamErrorChunk( + request_id=str(request_id), + error=StreamError( + status_code=code, + content=StreamErrorContent( + error=message, + timestamp=timestamp, + ), + ), + ).model_dump_json(exclude_none=True) + } + + +async def handle_stream_request( + external_interfaces: ExternalInterfaces, + background_tasks: BackgroundTasks, + request: CompletionV2Request, + auth: User, + model_endpoint_name: str, + metric_metadata: MetricMetadata, +): # pragma: no cover + use_case = CompletionStreamV2UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + + with timer() as use_case_timer: + try: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + + # We fetch the first response to check if upstream request was successful + # If it was not, this will raise the corresponding HTTPException + # If it was, we will proceed to the event generator + first_message: CompletionV2StreamSuccessChunk = await response.__anext__() + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + raise HTTPException( + status_code=404, + detail=str(exc), + ) from exc + except ( + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + ) as exc: + raise HTTPException( + status_code=400, + detail=str(exc), + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=str(exc)) from exc + except Exception as exc: + raise HTTPException( + status_code=500, + detail="Internal error occurred. Our team has been notified.", + ) from exc + + async def event_generator(timer: timer = use_case_timer): + try: + ttft = None + message = None + yield {"data": first_message.model_dump_json(exclude_none=True)} + async for message in response: + if ttft is None: + ttft = timer.lap() + # if ttft is None and message.startswith("data"): + # ttft = timer.lap() + yield {"data": message.model_dump_json(exclude_none=True)} + + if message: + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=( + message.usage.prompt_tokens if message.usage else None + ), + num_completion_tokens=( + message.usage.completion_tokens if message.usage else None + ), + total_duration=timer.duration, + ), + metric_metadata, + ) + + # The following two exceptions are only raised after streaming begins, so we wrap the exception within a Response object + except InvalidRequestException as exc: + yield handle_streaming_exception(exc, 400, str(exc)) + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + yield handle_streaming_exception( + exc, + 500, + f"Upstream service error for request_id {request_id}", + ) + except Exception as exc: + yield handle_streaming_exception( + exc, 500, "Internal error occurred. Our team has been notified." + ) + + return EventSourceResponse(event_generator()) + + +async def handle_sync_request( + external_interfaces: ExternalInterfaces, + request: CompletionV2Request, + background_tasks: BackgroundTasks, + auth: User, + model_endpoint_name: str, + metric_metadata: MetricMetadata, +): + try: + use_case = CompletionSyncV2UseCase( + model_endpoint_service=external_interfaces.model_endpoint_service, + llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, + tokenizer_repository=external_interfaces.tokenizer_repository, + ) + with timer() as use_case_timer: + response = await use_case.execute( + user=auth, model_endpoint_name=model_endpoint_name, request=request + ) + + background_tasks.add_task( + external_interfaces.monitoring_metrics_gateway.emit_token_count_metrics, + TokenUsage( + num_prompt_tokens=(response.usage.prompt_tokens if response.usage else None), + num_completion_tokens=( + response.usage.completion_tokens if response.usage else None + ), + total_duration=use_case_timer.duration, + ), + metric_metadata, + ) + return response + except UpstreamServiceError as exc: + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + logger.exception( + f"Upstream service error for request {request_id}. Error detail: {str(exc.content)}" + ) + raise HTTPException( + status_code=500, + detail=f"Upstream service error for request_id {request_id}", + ) + except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: + if isinstance(exc, ObjectNotAuthorizedException): # pragma: no cover + logger.info( + f"POST /completions-sync to endpoint {model_endpoint_name} for {auth} failed with authz error {exc.args}" + ) + + raise HTTPException( + status_code=404, + detail="The specified endpoint could not be found.", + ) from exc + except ObjectHasInvalidValueException as exc: + raise HTTPException(status_code=400, detail=to_error_details(exc)) + except InvalidRequestException as exc: + raise HTTPException(status_code=400, detail=to_error_details(exc)) + except EndpointUnsupportedRequestException as exc: + raise HTTPException( + status_code=400, + detail=f"Endpoint does not support request: {str(exc)}", + ) from exc + except EndpointUnsupportedInferenceTypeException as exc: + raise HTTPException( + status_code=400, + detail=f"Unsupported inference type: {str(exc)}", + ) from exc + + +def to_error_details(exc: Exception) -> Any: + if not exc.args or len(exc.args) == 0: + return str(exc) + if len(exc.args) == 1: + return exc.args[0] + else: + return exc.args + + +@completion_router_v2.post("/completions", response_model=CompletionV2ResponseItem) +async def completion( + request: CompletionV2Request, + background_tasks: BackgroundTasks, + auth: User = Depends(verify_authentication), + external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), + metric_metadata: MetricMetadata = Depends(get_metric_metadata), +) -> CompletionV2Response: # pragma: no cover + model_endpoint_name = request.model + if hmi_config.sensitive_log_mode: + logger.info( + f"POST /v2/completion ({('stream' if request.stream else 'sync')}) to endpoint {model_endpoint_name} for {auth}" + ) + else: + logger.info( + f"POST /v2/completion ({('stream' if request.stream else 'sync')}) with {request} to endpoint {model_endpoint_name} for {auth}" + ) + + if request.stream: + return await handle_stream_request( + external_interfaces=external_interfaces, + background_tasks=background_tasks, + request=request, + auth=auth, + model_endpoint_name=model_endpoint_name, + metric_metadata=metric_metadata, + ) + else: + return await handle_sync_request( + external_interfaces=external_interfaces, + background_tasks=background_tasks, + request=request, + auth=auth, + model_endpoint_name=model_endpoint_name, + metric_metadata=metric_metadata, + ) diff --git a/model-engine/model_engine_server/api/worker.py b/model-engine/model_engine_server/api/worker.py new file mode 100644 index 00000000..289640c8 --- /dev/null +++ b/model-engine/model_engine_server/api/worker.py @@ -0,0 +1,14 @@ +from uvicorn.workers import UvicornWorker + +# Gunicorn returns 503 instead of 429 when concurrency exceeds the limit +# We'll autoscale at target concurrency of a much lower number (around 50), and this just makes sure we don't 503 with bursty traffic +# We set this very high since model_engine_server/api/app.py sets a lower per-pod concurrency at which we start returning 429s +CONCURRENCY_LIMIT = 10000 + + +class LaunchWorker(UvicornWorker): + """Overrides the configuration of the Uvicorn Worker.""" + + # uvloop and httptools are both faster than their alternatives, but they are not compatible + # with Windows or PyPy. + CONFIG_KWARGS = {"loop": "uvloop", "http": "httptools", "limit_concurrency": CONCURRENCY_LIMIT} diff --git a/server/llm_engine_server/common/__init__.py b/model-engine/model_engine_server/common/__init__.py similarity index 100% rename from server/llm_engine_server/common/__init__.py rename to model-engine/model_engine_server/common/__init__.py diff --git a/server/llm_engine_server/infra/gateways/aiohttp_sse_client.py b/model-engine/model_engine_server/common/aiohttp_sse_client.py similarity index 100% rename from server/llm_engine_server/infra/gateways/aiohttp_sse_client.py rename to model-engine/model_engine_server/common/aiohttp_sse_client.py diff --git a/model-engine/model_engine_server/common/concurrency_limiter.py b/model-engine/model_engine_server/common/concurrency_limiter.py new file mode 100644 index 00000000..b4e10c81 --- /dev/null +++ b/model-engine/model_engine_server/common/concurrency_limiter.py @@ -0,0 +1,36 @@ +from multiprocessing import BoundedSemaphore +from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType +from typing import Optional + +from fastapi import HTTPException +from model_engine_server.core.loggers import logger_name, make_logger + +logger = make_logger(logger_name()) + + +class MultiprocessingConcurrencyLimiter: + def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): + self.concurrency = concurrency + if concurrency is not None: + if concurrency < 1: + raise ValueError("Concurrency should be at least 1") + self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency) + self.blocking = ( + not fail_on_concurrency_limit + ) # we want to block if we want to queue up requests + else: + self.semaphore = None + self.blocking = False # Unused + + def __enter__(self): + logger.debug("Entering concurrency limiter semaphore") + if self.semaphore and not self.semaphore.acquire(block=self.blocking): + logger.warning(f"Too many requests (max {self.concurrency}), returning 429") + raise HTTPException(status_code=429, detail="Too many requests") + # Just raises an HTTPException. + # __exit__ should not run; otherwise the release() doesn't have an acquire() + + def __exit__(self, type, value, traceback): + logger.debug("Exiting concurrency limiter semaphore") + if self.semaphore: + self.semaphore.release() diff --git a/model-engine/model_engine_server/common/config.py b/model-engine/model_engine_server/common/config.py new file mode 100644 index 00000000..1226d62a --- /dev/null +++ b/model-engine/model_engine_server/common/config.py @@ -0,0 +1,152 @@ +# Keep in line with service_config_{*}.yaml +# This file loads sensitive data that shouldn't make it to inference docker images +# Do not include this file in our inference/endpoint code +import inspect +import os +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Sequence + +import yaml +from azure.identity import DefaultAzureCredential +from model_engine_server.core.aws.secrets import get_key_file +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger + +logger = make_logger(logger_name()) + +__all__: Sequence[str] = ( + "DEFAULT_SERVICE_CONFIG_PATH", + "SERVICE_CONFIG_PATH", + "HostedModelInferenceServiceConfig", + "hmi_config", +) + +DEFAULT_SERVICE_CONFIG_PATH = str( + ( + Path(__file__).absolute().parent.parent.parent + / "service_configs" + / "service_config_circleci.yaml" + ).absolute() +) + +SERVICE_CONFIG_PATH = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH", DEFAULT_SERVICE_CONFIG_PATH) + +redis_cache_expiration_timestamp = None + + +# duplicated from llm/finetune_pipeline +def get_model_cache_directory_name(model_name: str): + """How huggingface maps model names to directory names in their cache for model files. + We adopt this when storing model cache files in s3. + + Args: + model_name (str): Name of the huggingface model + """ + name = "models--" + model_name.replace("/", "--") + return name + + +@dataclass +class HostedModelInferenceServiceConfig: + gateway_namespace: str + endpoint_namespace: str + billing_queue_arn: str + sqs_profile: str + sqs_queue_policy_template: str + sqs_queue_tag_template: str + model_primitive_host: str + cloud_file_llm_fine_tune_repository: str + hf_user_fine_tuned_weights_prefix: str + istio_enabled: bool + dd_trace_enabled: bool + tgi_repository: str + vllm_repository: str + lightllm_repository: str + tensorrt_llm_repository: str + batch_inference_vllm_repository: str + user_inference_base_repository: str + user_inference_pytorch_repository: str + user_inference_tensorflow_repository: str + docker_image_layer_cache_repository: str + sensitive_log_mode: bool + # Exactly one of the following three must be specified + cache_redis_aws_url: Optional[str] = None # also using this to store sync autoscaling metrics + cache_redis_azure_host: Optional[str] = None + cache_redis_aws_secret_name: Optional[str] = ( + None # Not an env var because the redis cache info is already here + ) + + @classmethod + def from_json(cls, json): + return cls(**{k: v for k, v in json.items() if k in inspect.signature(cls).parameters}) + + @classmethod + def from_yaml(cls, yaml_path): + with open(yaml_path, "r") as f: + raw_data = yaml.safe_load(f) + return HostedModelInferenceServiceConfig.from_json(raw_data) + + @property + def cache_redis_url(self) -> str: + if self.cache_redis_aws_url: + assert infra_config().cloud_provider == "aws", "cache_redis_aws_url is only for AWS" + if self.cache_redis_aws_secret_name: + logger.warning( + "Both cache_redis_aws_url and cache_redis_aws_secret_name are set. Using cache_redis_aws_url" + ) + return self.cache_redis_aws_url + elif self.cache_redis_aws_secret_name: + assert ( + infra_config().cloud_provider == "aws" + ), "cache_redis_aws_secret_name is only for AWS" + creds = get_key_file(self.cache_redis_aws_secret_name) # Use default role + return creds["cache-url"] + + assert self.cache_redis_azure_host and infra_config().cloud_provider == "azure" + username = os.getenv("AZURE_OBJECT_ID") + token = DefaultAzureCredential().get_token("https://redis.azure.com/.default") + password = token.token + global redis_cache_expiration_timestamp + redis_cache_expiration_timestamp = token.expires_on + return f"rediss://{username}:{password}@{self.cache_redis_azure_host}" + + @property + def cache_redis_url_expiration_timestamp(self) -> Optional[int]: + global redis_cache_expiration_timestamp + return redis_cache_expiration_timestamp + + @property + def cache_redis_host_port(self) -> str: + # redis://redis.url:6379/ + # -> redis.url:6379 + if "rediss://" in self.cache_redis_url: + return self.cache_redis_url.split("rediss://")[1].split("@")[-1].split("/")[0] + return self.cache_redis_url.split("redis://")[1].split("/")[0] + + @property + def cache_redis_db_index(self) -> int: + # redis://redis.url:6379/ + # -> + try: + return int(self.cache_redis_url.split("/")[-1]) + except ValueError: + return 0 # 0 is the default index used by redis if it's not specified + + +def read_default_config(): + logger.info(f"Using config file path: `{SERVICE_CONFIG_PATH}`") + return HostedModelInferenceServiceConfig.from_yaml(SERVICE_CONFIG_PATH) + + +_hmi_config: Optional[HostedModelInferenceServiceConfig] = None + + +def get_hmi_config() -> HostedModelInferenceServiceConfig: + global _hmi_config + if _hmi_config is None: + _hmi_config = read_default_config() + return _hmi_config + + +hmi_config = get_hmi_config() diff --git a/model-engine/model_engine_server/common/constants.py b/model-engine/model_engine_server/common/constants.py new file mode 100644 index 00000000..00d00d6c --- /dev/null +++ b/model-engine/model_engine_server/common/constants.py @@ -0,0 +1,16 @@ +from pathlib import Path + +BILLING_POST_INFERENCE_HOOK: str = "billing" +CALLBACK_POST_INFERENCE_HOOK: str = "callback" +LOGGING_POST_INFERENCE_HOOK: str = "logging" +SUPPORTED_POST_INFERENCE_HOOKS: list = [ + BILLING_POST_INFERENCE_HOOK, + CALLBACK_POST_INFERENCE_HOOK, + LOGGING_POST_INFERENCE_HOOK, +] +READYZ_FPATH: str = "/tmp/readyz" +DEFAULT_CELERY_TASK_NAME: str = "hosted_model_inference.inference.async_inference.tasks.predict" +LIRA_CELERY_TASK_NAME: str = "ml_serve.celery_service.exec_func" + +PROJECT_ROOT: Path = Path(__file__).parents[2].absolute() +HOSTED_MODEL_INFERENCE_ROOT: Path = PROJECT_ROOT / "model-engine" diff --git a/model-engine/model_engine_server/common/datadog_utils.py b/model-engine/model_engine_server/common/datadog_utils.py new file mode 100644 index 00000000..f5b2844e --- /dev/null +++ b/model-engine/model_engine_server/common/datadog_utils.py @@ -0,0 +1,30 @@ +from typing import Optional + +from ddtrace import tracer + + +def add_trace_request_id(request_id: Optional[str]): + """Adds a custom tag to a given dd trace corresponding to the request id + so that we can filter in Datadog easier + """ + if not request_id: + return + + current_span = tracer.current_span() + if current_span: + current_span.set_tag("launch.request_id", request_id) + + +def add_trace_model_name(model_name: Optional[str]): + """Adds a custom tag to a given dd trace corresponding to the model name + so that we can filter in Datadog easier + + Only use this when the number of model names is small, otherwise it will + blow up the cardinality in Datadog + """ + if not model_name: + return + + current_span = tracer.current_span() + if current_span: + current_span.set_tag("launch.model_name", model_name) diff --git a/server/llm_engine_server/core/aws/__init__.py b/model-engine/model_engine_server/common/dtos/__init__.py similarity index 100% rename from server/llm_engine_server/core/aws/__init__.py rename to model-engine/model_engine_server/common/dtos/__init__.py diff --git a/server/llm_engine_server/common/dtos/batch_jobs.py b/model-engine/model_engine_server/common/dtos/batch_jobs.py similarity index 75% rename from server/llm_engine_server/common/dtos/batch_jobs.py rename to model-engine/model_engine_server/common/dtos/batch_jobs.py index e1fc45fa..8d24665e 100644 --- a/server/llm_engine_server/common/dtos/batch_jobs.py +++ b/model-engine/model_engine_server/common/dtos/batch_jobs.py @@ -1,28 +1,30 @@ """ DTOs for the batch job abstraction. """ + from datetime import datetime, timedelta from typing import Any, Collection, Dict, List, Optional -from llm_engine_server.common import dict_not_none -from llm_engine_server.domain.entities import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, model_validator +from model_engine_server.domain.entities import ( BatchJobSerializationFormat, BatchJobStatus, CpuSpecificationType, + DockerImageBatchJob, GpuType, StorageSpecificationType, ) -from pydantic import BaseModel, root_validator class CreateBatchJobResourceRequests(BaseModel): - cpus: Optional[CpuSpecificationType] - memory: Optional[StorageSpecificationType] - gpus: Optional[int] - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - max_workers: Optional[int] - per_worker: Optional[int] + cpus: Optional[CpuSpecificationType] = None + memory: Optional[StorageSpecificationType] = None + gpus: Optional[int] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + max_workers: Optional[int] = None + per_worker: Optional[int] = None class CreateBatchJobV1Request(BaseModel): @@ -40,10 +42,10 @@ class CreateBatchJobV1Response(BaseModel): class GetBatchJobV1Response(BaseModel): status: BatchJobStatus - result: Optional[str] + result: Optional[str] = None duration: timedelta - num_tasks_pending: Optional[int] - num_tasks_completed: Optional[int] + num_tasks_pending: Optional[int] = None + num_tasks_completed: Optional[int] = None class UpdateBatchJobV1Request(BaseModel): @@ -63,9 +65,10 @@ class CreateDockerImageBatchJobResourceRequests(BaseModel): gpus: Optional[int] = None gpu_type: Optional[GpuType] = None storage: Optional[StorageSpecificationType] = None - - class Config: - orm_mode = True + nodes_per_worker: Optional[int] = ( + None # TODO this is used only for inferring hardware, if multinode batch jobs is added we can reuse this field + ) + model_config = ConfigDict(from_attributes=True) @classmethod def merge_requests( @@ -92,7 +95,7 @@ def common_requests( class CreateDockerImageBatchJobV1Request(BaseModel): docker_image_batch_job_bundle_name: Optional[str] = None docker_image_batch_job_bundle_id: Optional[str] = None - job_config: Optional[Dict[str, Any]] + job_config: Optional[Dict[str, Any]] = None # TODO also expose a separate argument to pass an s3file to the job, as opposed to job_config labels: Dict[str, str] # TODO this probably should go in the bundle @@ -100,7 +103,9 @@ class CreateDockerImageBatchJobV1Request(BaseModel): CreateDockerImageBatchJobResourceRequests() ) - @root_validator + override_job_max_runtime_s: Optional[int] = None + + @model_validator(mode="before") def exactly_one_name_or_id(cls, values): bundle_name = values.get("docker_image_batch_job_bundle_name") bundle_id = values.get("docker_image_batch_job_bundle_id") @@ -123,6 +128,10 @@ class GetDockerImageBatchJobV1Response(BaseModel): status: BatchJobStatus +class ListDockerImageBatchJobsV1Response(BaseModel): + jobs: List[DockerImageBatchJob] + + class UpdateDockerImageBatchJobV1Request(BaseModel): cancel: bool @@ -159,16 +168,14 @@ class DockerImageBatchJobBundleV1Response(BaseModel): image_tag: str command: List[str] env: Dict[str, str] - mount_location: Optional[str] - cpus: Optional[str] - memory: Optional[str] - storage: Optional[str] - gpus: Optional[int] - gpu_type: Optional[str] - public: Optional[bool] - - class Config: - orm_mode = True + mount_location: Optional[str] = None + cpus: Optional[str] = None + memory: Optional[str] = None + storage: Optional[str] = None + gpus: Optional[int] = None + gpu_type: Optional[str] = None + public: Optional[bool] = None + model_config = ConfigDict(from_attributes=True) class ListDockerImageBatchJobBundleV1Response(BaseModel): diff --git a/model-engine/model_engine_server/common/dtos/core.py b/model-engine/model_engine_server/common/dtos/core.py new file mode 100644 index 00000000..c8d2ee22 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/core.py @@ -0,0 +1,11 @@ +from pydantic import BeforeValidator, HttpUrl, TypeAdapter +from typing_extensions import Annotated + +# See: https://github.com/pydantic/pydantic/issues/7186 +# pydantic v2 doesn't treat HttpUrl the same way as in v1 which causes various issue +# This is an attempt to make it behave as similar as possible +HttpUrlTypeAdapter = TypeAdapter(HttpUrl) +HttpUrlStr = Annotated[ + str, + BeforeValidator(lambda value: HttpUrlTypeAdapter.validate_python(value) and value), +] diff --git a/server/llm_engine_server/common/dtos/docker_repository.py b/model-engine/model_engine_server/common/dtos/docker_repository.py similarity index 65% rename from server/llm_engine_server/common/dtos/docker_repository.py rename to model-engine/model_engine_server/common/dtos/docker_repository.py index 5548eead..a5ddc1cf 100644 --- a/server/llm_engine_server/common/dtos/docker_repository.py +++ b/model-engine/model_engine_server/common/dtos/docker_repository.py @@ -1,6 +1,6 @@ from typing import Dict, Optional -from pydantic import BaseModel +from model_engine_server.common.pydantic_types import BaseModel class BuildImageRequest(BaseModel): @@ -10,13 +10,14 @@ class BuildImageRequest(BaseModel): base_path: str dockerfile: str base_image: str - requirements_folder: Optional[str] - substitution_args: Optional[Dict[str, str]] + requirements_folder: Optional[str] = None + substitution_args: Optional[Dict[str, str]] = None class BuildImageResponse(BaseModel): status: bool logs: str + job_name: str # TODO: We may want to add a DTO for streaming logs from the docker build to users. diff --git a/server/llm_engine_server/common/dtos/endpoint_builder.py b/model-engine/model_engine_server/common/dtos/endpoint_builder.py similarity index 54% rename from server/llm_engine_server/common/dtos/endpoint_builder.py rename to model-engine/model_engine_server/common/dtos/endpoint_builder.py index 9817fbbc..2f5c5dbc 100644 --- a/server/llm_engine_server/common/dtos/endpoint_builder.py +++ b/model-engine/model_engine_server/common/dtos/endpoint_builder.py @@ -1,14 +1,14 @@ from enum import Enum from typing import Any, Dict, List, Optional -from llm_engine_server.domain.entities import ( +from model_engine_server.common.pydantic_types import BaseModel +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, ModelEndpointRecord, StorageSpecificationType, ) -from pydantic import BaseModel class BuildEndpointRequest(BaseModel): @@ -20,18 +20,20 @@ class BuildEndpointRequest(BaseModel): cpus: CpuSpecificationType gpus: int memory: StorageSpecificationType - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + nodes_per_worker: int = 1 # Multinode support. >1 = multinode. optimize_costs: bool aws_role: str results_s3_bucket: str - child_fn_info: Optional[Dict[str, Any]] # TODO: remove this if we don't need it. - post_inference_hooks: Optional[List[str]] + child_fn_info: Optional[Dict[str, Any]] = None # TODO: remove this if we don't need it. + post_inference_hooks: Optional[List[str]] = None labels: Dict[str, str] + billing_tags: Optional[Dict[str, Any]] = None prewarm: bool = True - high_priority: Optional[bool] - default_callback_url: Optional[str] - default_callback_auth: Optional[CallbackAuth] + high_priority: Optional[bool] = None + default_callback_url: Optional[str] = None + default_callback_auth: Optional[CallbackAuth] = None class BuildEndpointStatus(str, Enum): diff --git a/model-engine/model_engine_server/common/dtos/files.py b/model-engine/model_engine_server/common/dtos/files.py new file mode 100644 index 00000000..8fa6e8a8 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/files.py @@ -0,0 +1,48 @@ +""" +DTOs for Files API. +""" + +from typing import List + +from model_engine_server.common.pydantic_types import BaseModel, Field + + +class UploadFileResponse(BaseModel): + """Response object for uploading a file.""" + + id: str = Field(..., description="ID of the uploaded file.") + """ID of the uploaded file.""" + + +class GetFileResponse(BaseModel): + """Response object for retrieving a file.""" + + id: str = Field(..., description="ID of the requested file.") + """ID of the requested file.""" + filename: str = Field(..., description="File name.") + """File name.""" + size: int = Field(..., description="Length of the file, in characters.") + """Length of the file, in characters.""" + + +class ListFilesResponse(BaseModel): + """Response object for listing files.""" + + files: List[GetFileResponse] = Field(..., description="List of file IDs, names, and sizes.") + """List of file IDs, names, and sizes.""" + + +class DeleteFileResponse(BaseModel): + """Response object for deleting a file.""" + + deleted: bool = Field(..., description="Whether deletion was successful.") + """Whether deletion was successful.""" + + +class GetFileContentResponse(BaseModel): + """Response object for retrieving a file's content.""" + + id: str = Field(..., description="ID of the requested file.") + """ID of the requested file.""" + content: str = Field(..., description="File content.") + """File content.""" diff --git a/model-engine/model_engine_server/common/dtos/llms/__init__.py b/model-engine/model_engine_server/common/dtos/llms/__init__.py new file mode 100644 index 00000000..663be186 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/__init__.py @@ -0,0 +1,9 @@ +""" +DTOs for LLM APIs. +""" + +from .batch_completion import * # noqa: F403 +from .chat_completion import * # noqa: F403 +from .completion import * # noqa: F403 +from .model_endpoints import * # noqa: F403 +from .vllm import * # noqa: F403 diff --git a/model-engine/model_engine_server/common/dtos/llms/batch_completion.py b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py new file mode 100644 index 00000000..9f1eea1e --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/batch_completion.py @@ -0,0 +1,373 @@ +from enum import Enum +from typing import Dict, List, Optional + +from model_engine_server.common.dtos.llms.chat_completion import ( + ChatCompletionV2Request, + ChatCompletionV2SyncResponse, +) +from model_engine_server.common.dtos.llms.completion import ( + CompletionOutput, + CompletionV2Request, + CompletionV2SyncResponse, +) +from model_engine_server.common.dtos.llms.vllm import VLLMEngineAdditionalArgs, VLLMModelConfig +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field +from model_engine_server.domain.entities.common_types import ( + CpuSpecificationType, + StorageSpecificationType, +) +from model_engine_server.domain.entities.gpu_type import GpuType +from typing_extensions import TypeAlias + + +# Common DTOs for batch completions +class ToolConfig(BaseModel): + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + name: str + """ + Name of the tool to use for the batch inference. + """ + max_iterations: Optional[int] = 10 + """ + Maximum number of iterations to run the tool. + """ + execution_timeout_seconds: Optional[int] = 60 + """ + Maximum runtime of the tool in seconds. + """ + should_retry_on_error: Optional[bool] = True + """ + Whether to retry the tool on error. + """ + + +class BatchCompletionsModelConfig(VLLMModelConfig): + model: str = Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ) + + checkpoint_path: Optional[str] = Field( + default=None, description="Path to the checkpoint to load the model from." + ) + + num_shards: Optional[int] = Field( + default=1, + ge=1, + description=""" +Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. +System may decide to use a different number than the given value. +""", + ) + + max_context_length: Optional[int] = Field( + default=None, + ge=1, + description="Maximum context length to use for the model. Defaults to the max allowed by the model. Deprecated in favor of max_model_len.", + ) + + seed: Optional[int] = Field(default=None, description="Random seed for the model.") + + response_role: Optional[str] = Field( + default=None, + description="Role of the response in the conversation. Only supported in chat completions.", + ) + + +class BatchCompletionsRequestBase(BaseModel): + input_data_path: Optional[str] = Field( + default=None, + description="Path to the input file. The input file should be a JSON file of type List[CreateBatchCompletionsRequestContent].", + ) + output_data_path: str = Field( + description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." + ) + + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + + data_parallelism: Optional[int] = Field( + default=1, + ge=1, + le=64, + description="Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference.", + ) + + max_runtime_sec: Optional[int] = Field( + default=24 * 3600, + ge=1, + le=2 * 24 * 3600, + description="Maximum runtime of the batch inference in seconds. Default to one day.", + ) + + priority: Optional[str] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + + tool_config: Optional[ToolConfig] = Field( + default=None, + description=""" +Configuration for tool use. +NOTE: this config is highly experimental and signature will change significantly in future iterations.""", + ) + + cpus: Optional[CpuSpecificationType] = Field( + default=None, description="CPUs to use for the batch inference." + ) + gpus: Optional[int] = Field( + default=None, description="Number of GPUs to use for the batch inference." + ) + memory: Optional[StorageSpecificationType] = Field( + default=None, description="Amount of memory to use for the batch inference." + ) + gpu_type: Optional[GpuType] = Field( + default=None, description="GPU type to use for the batch inference." + ) + storage: Optional[StorageSpecificationType] = Field( + default=None, description="Storage to use for the batch inference." + ) + nodes_per_worker: Optional[int] = Field( + default=None, description="Number of nodes per worker for the batch inference." + ) + + +# V1 DTOs for batch completions +CompletionV1Output = CompletionOutput + + +class CreateBatchCompletionsV1ModelConfig(BatchCompletionsModelConfig): + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + + +class CreateBatchCompletionsV1RequestContent(BaseModel): + prompts: List[str] + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. + """ + + +class CreateBatchCompletionsV1Request(BatchCompletionsRequestBase): + """ + Request object for batch completions. + """ + + content: Optional[CreateBatchCompletionsV1RequestContent] = None + """ + Either `input_data_path` or `content` needs to be provided. + When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. + """ + model_cfg: CreateBatchCompletionsV1ModelConfig = Field(alias="model_config") + """ + Model configuration for the batch inference. Hardware configurations are inferred. + + We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + reserves model_config as a keyword. + """ + + +class CreateBatchCompletionsV1Response(BaseModel): + job_id: str + + +class FilteredCompletionV2Request(CompletionV2Request): + model: Optional[str] = None # type: ignore[assignment] + stream: Optional[bool] = False + + +class FilteredChatCompletionV2Request(ChatCompletionV2Request): + model: Optional[str] = None # type: ignore[assignment] + stream: Optional[bool] = False + + +# V2 DTOs for batch completions +CompletionRequest: TypeAlias = FilteredCompletionV2Request | FilteredChatCompletionV2Request +CompletionResponse: TypeAlias = CompletionV2SyncResponse | ChatCompletionV2SyncResponse +CreateBatchCompletionsV2RequestContent: TypeAlias = ( + List[FilteredCompletionV2Request] | List[FilteredChatCompletionV2Request] +) + +CreateBatchCompletionsV2ModelConfig: TypeAlias = BatchCompletionsModelConfig +BatchCompletionContent = ( + CreateBatchCompletionsV1RequestContent | CreateBatchCompletionsV2RequestContent +) + + +class CreateBatchCompletionsV2Request(BatchCompletionsRequestBase): + """ + Request object for batch completions. + """ + + content: Optional[BatchCompletionContent] = Field( + default=None, + description=""" +Either `input_data_path` or `content` needs to be provided. +When input_data_path is provided, the input file should be a JSON file of type List[CreateBatchCompletionsRequestContent]. +""", + ) + + # We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + # reserves model_config as a keyword. + model_cfg: BatchCompletionsModelConfig = Field( + alias="model_config", + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + +class BatchCompletionsJobStatus(str, Enum): + Queued = "queued" + Running = "running" + Completed = "completed" + Failed = "failed" + Cancelled = "cancelled" + Unknown = "unknown" + + +class BatchCompletionsJob(BaseModel): + job_id: str + input_data_path: Optional[str] = Field( + default=None, + description="Path to the input file. The input file should be a JSON file of type List[CreateBatchCompletionsRequestContent].", + ) + output_data_path: str = Field( + description="Path to the output file. The output file will be a JSON file of type List[CompletionOutput]." + ) + + # We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + # reserves model_config as a keyword. + model_cfg: BatchCompletionsModelConfig = Field( + alias="model_config", + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + priority: Optional[str] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + status: BatchCompletionsJobStatus + created_at: str + expires_at: str + completed_at: Optional[str] + metadata: Optional[Dict[str, str]] + + +CreateBatchCompletionsV2Response: TypeAlias = BatchCompletionsJob + + +class UpdateBatchCompletionsV2Request(BaseModel): + job_id: str = Field(description="ID of the batch completions job") + priority: Optional[str] = Field( + default=None, + description="Priority of the batch inference job. Default to None.", + ) + + +class UpdateBatchCompletionsV2Response(BatchCompletionsJob): + success: bool = Field(description="Whether the update was successful") + + +class CancelBatchCompletionsV2Request(BaseModel): + job_id: str = Field(description="ID of the batch completions job") + + +class CancelBatchCompletionsV2Response(BaseModel): + success: bool = Field(description="Whether the cancellation was successful") + + +class ListBatchCompletionV2Response(BaseModel): + jobs: List[BatchCompletionsJob] + + +class GetBatchCompletionV2Response(BaseModel): + job: BatchCompletionsJob + + +class CreateBatchCompletionsEngineRequest(BatchCompletionsRequestBase, VLLMEngineAdditionalArgs): + """ + Internal model for representing request to the inference framework. This contains additional fields that we want + hidden from the DTO exposed to the client. + """ + + model_config = ConfigDict(populate_by_name=True, protected_namespaces=()) + + content: Optional[BatchCompletionContent] = Field( + default=None, + description="Content is a union of the content from v1 and v2 requests.", + ) + + model_cfg: BatchCompletionsModelConfig = Field( + alias="model_config", + description="""Model configuration for the batch inference. Hardware configurations are inferred.""", + ) + + @staticmethod + def from_api_v1( + request: CreateBatchCompletionsV1Request, + ) -> "CreateBatchCompletionsEngineRequest": + return CreateBatchCompletionsEngineRequest( + input_data_path=request.input_data_path, + output_data_path=request.output_data_path, + content=request.content, + model_config=request.model_cfg, + model_cfg=request.model_cfg, + data_parallelism=request.data_parallelism, + max_runtime_sec=request.max_runtime_sec, + tool_config=request.tool_config, + labels=request.model_cfg.labels, + priority=request.priority, + ) + + @staticmethod + def from_api_v2( + request: CreateBatchCompletionsV2Request, + ) -> "CreateBatchCompletionsEngineRequest": + return CreateBatchCompletionsEngineRequest( + input_data_path=request.input_data_path, + output_data_path=request.output_data_path, + content=request.content, + model_config=request.model_cfg, + model_cfg=request.model_cfg, + data_parallelism=request.data_parallelism, + max_runtime_sec=request.max_runtime_sec, + labels=request.labels, + priority=request.priority, + ) diff --git a/model-engine/model_engine_server/common/dtos/llms/chat_completion.py b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py new file mode 100644 index 00000000..a5f89394 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/chat_completion.py @@ -0,0 +1,55 @@ +from typing import Optional + +from model_engine_server.common.dtos.llms.completion import StreamError +from model_engine_server.common.dtos.llms.vllm import VLLMChatCompletionAdditionalParams +from model_engine_server.common.pydantic_types import BaseModel, Field +from model_engine_server.common.types.gen.openai import ( + CreateChatCompletionRequest, + CreateChatCompletionResponse, + CreateChatCompletionStreamResponse, +) +from sse_starlette import EventSourceResponse +from typing_extensions import Annotated, TypeAlias + +# Fields that are a part of OpenAI spec but are not supported by model engine +UNSUPPORTED_FIELDS = ["service_tier"] + + +class ChatCompletionV2Request(CreateChatCompletionRequest, VLLMChatCompletionAdditionalParams): + model: Annotated[ + str, + Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ), + ] + + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + + +ChatCompletionV2SyncResponse: TypeAlias = CreateChatCompletionResponse +ChatCompletionV2StreamSuccessChunk: TypeAlias = CreateChatCompletionStreamResponse + + +class ChatCompletionV2StreamErrorChunk(BaseModel): + error: StreamError + + +ChatCompletionV2Chunk: TypeAlias = ( + ChatCompletionV2StreamSuccessChunk | ChatCompletionV2StreamErrorChunk +) +ChatCompletionV2StreamResponse: TypeAlias = ( + EventSourceResponse # EventSourceResponse[ChatCompletionV2Chunk] +) + +ChatCompletionV2Response: TypeAlias = ChatCompletionV2SyncResponse | ChatCompletionV2StreamResponse + +# This is a version of ChatCompletionV2Response that is used by pydantic to determine the response model +# Since EventSourceResponse isn't a pydantic model, we need to use a Union of the two response types +ChatCompletionV2ResponseItem: TypeAlias = ChatCompletionV2SyncResponse | ChatCompletionV2Chunk diff --git a/model-engine/model_engine_server/common/dtos/llms/completion.py b/model-engine/model_engine_server/common/dtos/llms/completion.py new file mode 100644 index 00000000..44ae72db --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/completion.py @@ -0,0 +1,342 @@ +from typing import Any, Dict, List, Optional, TypeAlias + +from model_engine_server.common.dtos.llms.vllm import VLLMCompletionAdditionalParams +from model_engine_server.common.pydantic_types import BaseModel, Field +from model_engine_server.common.types.gen.openai import ( + CreateCompletionRequest, + CreateCompletionResponse, +) +from sse_starlette import EventSourceResponse +from typing_extensions import Annotated + +# Fields that are a part of OpenAI spec but are not supported by model engine +UNSUPPORTED_FIELDS = ["service_tier"] + + +class CompletionSyncV1Request(BaseModel): + """ + Request object for a synchronous prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ + + +class TokenOutput(BaseModel): + """ + Detailed token information. + """ + + token: str + """ + The token text. + """ + + log_prob: float + """ + The log probability of the token. + """ + + +class CompletionOutput(BaseModel): + """ + Represents the output of a completion request to a model. + """ + + text: str + """The text of the completion.""" + + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + + num_completion_tokens: int + """Number of tokens in the completion.""" + + tokens: Optional[List[TokenOutput]] = None + """Detailed token information.""" + + +class CompletionSyncV1Response(BaseModel): + """ + Response object for a synchronous prompt completion. + """ + + request_id: Optional[str] = None + """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs + associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as + follows: + + * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. + * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability + provider.""" + + output: Optional[CompletionOutput] = None + """Completion output.""" + + +class CompletionStreamV1Request(BaseModel): + """ + Request object for a stream prompt completion task. + """ + + prompt: str + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + include_stop_str_in_output: Optional[bool] = None + """ + Whether to include the stop strings in output text. + """ + guided_json: Optional[Dict[str, Any]] = None + """ + JSON schema for guided decoding. Only supported in vllm. + """ + guided_regex: Optional[str] = None + """ + Regex for guided decoding. Only supported in vllm. + """ + guided_choice: Optional[List[str]] = None + """ + Choices for guided decoding. Only supported in vllm. + """ + guided_grammar: Optional[str] = None + """ + Context-free grammar for guided decoding. Only supported in vllm. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. Only supported in vllm. + """ + + +class CompletionStreamOutput(BaseModel): + text: str + """The text of the completion.""" + + finished: bool + """Whether the completion is finished.""" + + # We're not guaranteed to have `num_prompt_tokens` in the response in all cases, so to be safe, set a default. + num_prompt_tokens: Optional[int] = None + """Number of tokens in the prompt.""" + + num_completion_tokens: Optional[int] = None + """Number of tokens in the completion.""" + + token: Optional[TokenOutput] = None + """Detailed token information.""" + + +class StreamErrorContent(BaseModel): + error: str + """Error message.""" + timestamp: str + """Timestamp of the error.""" + + +class StreamError(BaseModel): + """ + Error object for a stream prompt completion task. + """ + + status_code: int + """The HTTP status code of the error.""" + content: StreamErrorContent + """The error content.""" + + +class CompletionStreamV1Response(BaseModel): + """Error of the response (if any).""" + + """ + Response object for a stream prompt completion task. + """ + + request_id: Optional[str] + """The unique ID of the corresponding Completion request. This `request_id` is generated on the server, and all logs + associated with the request are grouped by the `request_id`, which allows for easier troubleshooting of errors as + follows: + + * When running the *Scale-hosted* LLM Engine, please provide the `request_id` in any bug reports. + * When running the *self-hosted* LLM Engine, the `request_id` serves as a trace ID in your observability + provider.""" + + output: Optional[CompletionStreamOutput] = None + """Completion output.""" + + error: Optional[StreamError] = None + """Error of the response (if any).""" + + +class TokenUsage(BaseModel): + """ + Token usage for a prompt completion task. + """ + + num_prompt_tokens: Optional[int] = 0 + num_completion_tokens: Optional[int] = 0 + total_duration: Optional[float] = None + """Includes time spent waiting for the model to be ready.""" + + time_to_first_token: Optional[float] = None # Only for streaming requests + + @property + def num_total_tokens(self) -> int: + return (self.num_prompt_tokens or 0) + (self.num_completion_tokens or 0) + + @property + def total_tokens_per_second(self) -> float: + return ( + self.num_total_tokens / self.total_duration + if self.total_duration and self.total_duration > 0 + else 0.0 + ) + + @property + def inter_token_latency(self) -> Optional[float]: # Only for streaming requests + # Note: we calculate a single inter-token latency for the entire request. + # Calculating latency between each token seems a bit heavyweight, although we can do this if we wanted + if ( + self.time_to_first_token is None + or self.num_completion_tokens is None + or self.total_duration is None + ): + return None + if self.num_completion_tokens < 2: + return None + return (self.total_duration - self.time_to_first_token) / (self.num_completion_tokens - 1) + + +class CompletionV2Request(CreateCompletionRequest, VLLMCompletionAdditionalParams): + model: Annotated[ + str, + Field( + description="ID of the model to use.", + examples=["mixtral-8x7b-instruct"], + ), + ] + + stream: Annotated[ + Optional[bool], + Field( + False, + description="If set, partial message deltas will be sent. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n", + ), + ] + + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + + include_stop_str_in_output: Annotated[ + Optional[bool], + Field(None, description="Whether to include the stop strings in output text."), + ] + + +CompletionV2SyncResponse: TypeAlias = CreateCompletionResponse +CompletionV2StreamSuccessChunk: TypeAlias = CreateCompletionResponse + + +class CompletionV2StreamErrorChunk(BaseModel): + error: StreamError + + +CompletionV2StreamChunk: TypeAlias = CompletionV2StreamSuccessChunk | CompletionV2StreamErrorChunk +CompletionV2StreamResponse: TypeAlias = ( + EventSourceResponse # EventSourceResponse[CompletionV2StreamChunk] +) + +CompletionV2Response: TypeAlias = CompletionV2SyncResponse | CompletionV2StreamResponse + +# This is a version of CompletionV2Response that is used by pydantic to determine the response model +# Since EventSourceResponse isn't a pydantic model, we need to use a Union of the two response types +CompletionV2ResponseItem: TypeAlias = CompletionV2SyncResponse | CompletionV2StreamChunk diff --git a/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py new file mode 100644 index 00000000..71cf4e69 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/model_endpoints.py @@ -0,0 +1,225 @@ +""" +DTOs for LLM APIs. + +""" + +from typing import Any, Dict, List, Optional + +from model_engine_server.common.dtos.core import HttpUrlStr +from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs +from model_engine_server.common.dtos.model_endpoints import ( + CpuSpecificationType, + GetModelEndpointV1Response, + GpuType, + ModelEndpointType, + StorageSpecificationType, +) +from model_engine_server.common.pydantic_types import BaseModel, Field +from model_engine_server.domain.entities import ( + BatchJobStatus, + CallbackAuth, + FineTuneHparamValueType, + LLMFineTuneEvent, + LLMInferenceFramework, + LLMSource, + ModelEndpointStatus, + Quantization, +) + + +class CreateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel): + name: str + + # LLM specific fields + model_name: str + source: LLMSource = LLMSource.HUGGING_FACE + inference_framework: LLMInferenceFramework = LLMInferenceFramework.VLLM + inference_framework_image_tag: str = "latest" + num_shards: int = 1 + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Dict[str, Any] # TODO: JSON type + post_inference_hooks: Optional[List[str]] = None + endpoint_type: ModelEndpointType = ModelEndpointType.SYNC + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + nodes_per_worker: Optional[int] = None + optimize_costs: Optional[bool] = None + min_workers: int + max_workers: int + per_worker: int + labels: Dict[str, str] + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = True # LLM endpoints are public by default. + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + +class CreateLLMModelEndpointV1Response(BaseModel): + endpoint_creation_task_id: str + + +class GetLLMModelEndpointV1Response(BaseModel): + id: str + """ + The autogenerated ID of the Launch endpoint. + """ + + name: str + model_name: str + source: LLMSource + status: ModelEndpointStatus + inference_framework: LLMInferenceFramework + inference_framework_image_tag: Optional[str] = None + num_shards: Optional[int] = None + quantize: Optional[Quantization] = None + checkpoint_path: Optional[str] = None + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + spec: Optional[GetModelEndpointV1Response] = None + + +class ListLLMModelEndpointsV1Response(BaseModel): + model_endpoints: List[GetLLMModelEndpointV1Response] + + +class UpdateLLMModelEndpointV1Request(VLLMEndpointAdditionalArgs, BaseModel): + # LLM specific fields + model_name: Optional[str] = None + source: Optional[LLMSource] = None + inference_framework_image_tag: Optional[str] = None + num_shards: Optional[int] = None + """ + Number of shards to distribute the model onto GPUs. + """ + + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + + # General endpoint fields + metadata: Optional[Dict[str, Any]] = None + post_inference_hooks: Optional[List[str]] = None + cpus: Optional[CpuSpecificationType] = None + gpus: Optional[int] = None + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None + min_workers: Optional[int] = None + max_workers: Optional[int] = None + per_worker: Optional[int] = None + labels: Optional[Dict[str, str]] = None + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = None + chat_template_override: Optional[str] = Field( + default=None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + force_bundle_recreation: Optional[bool] = False + """ + Whether to force recreate the underlying bundle. + + If True, the underlying bundle will be recreated. This is useful if there are underlying implementation changes with how bundles are created + that we would like to pick up for existing endpoints + """ + + +class UpdateLLMModelEndpointV1Response(BaseModel): + endpoint_creation_task_id: str + + +class CreateFineTuneRequest(BaseModel): + model: str + training_file: str + validation_file: Optional[str] = None + # fine_tuning_method: str # TODO enum + uncomment when we support multiple methods + hyperparameters: Dict[str, FineTuneHparamValueType] # validated somewhere else + suffix: Optional[str] = None + wandb_config: Optional[Dict[str, Any]] = None + """ + Config to pass to wandb for init. See https://docs.wandb.ai/ref/python/init + Must include `api_key` field which is the wandb API key. + """ + + +class CreateFineTuneResponse(BaseModel): + id: str + + +class GetFineTuneResponse(BaseModel): + id: str = Field(..., description="Unique ID of the fine tune") + fine_tuned_model: Optional[str] = Field( + default=None, + description="Name of the resulting fine-tuned model. This can be plugged into the " + "Completion API ones the fine-tune is complete", + ) + status: BatchJobStatus = Field(..., description="Status of the requested fine tune.") + + +class ListFineTunesResponse(BaseModel): + jobs: List[GetFineTuneResponse] + + +class CancelFineTuneResponse(BaseModel): + success: bool + + +class GetFineTuneEventsResponse(BaseModel): + # LLMFineTuneEvent is entity layer technically, but it's really simple + events: List[LLMFineTuneEvent] + + +class ModelDownloadRequest(BaseModel): + model_name: str = Field(..., description="Name of the fine tuned model") + download_format: Optional[str] = Field( + default="hugging_face", + description="Format that you want the downloaded urls to be compatible with. Currently only supports hugging_face", + ) + + +class ModelDownloadResponse(BaseModel): + urls: Dict[str, str] = Field( + ..., + description="Dictionary of (file_name, url) pairs to download the model from.", + ) + + +# Delete uses the default Launch endpoint APIs. +class DeleteLLMEndpointResponse(BaseModel): + deleted: bool diff --git a/model-engine/model_engine_server/common/dtos/llms/vllm.py b/model-engine/model_engine_server/common/dtos/llms/vllm.py new file mode 100644 index 00000000..af6a6710 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/llms/vllm.py @@ -0,0 +1,339 @@ +from typing import Any, Dict, List, Optional + +from model_engine_server.common.pydantic_types import BaseModel, Field +from model_engine_server.common.types.gen.openai import ( + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ResponseFormatText, +) +from typing_extensions import Annotated + +# This was last synced w/ vLLM v0.5.5 on 2024-09-03 + + +class VLLMModelConfig(BaseModel): + """Model configuration for VLLM""" + + max_model_len: Optional[int] = Field( + None, + description="""Model context length, If unspecified, will be automatically derived from the model config""", + ) + + max_num_seqs: Optional[int] = Field( + None, + description="""Maximum number of sequences per iteration""", + ) + + enforce_eager: Optional[bool] = Field( + None, + description="""Always use eager-mode PyTorch. If False, will use eager mode and CUDA graph in hybrid for maximal perforamnce and flexibility""", + ) + + gpu_memory_utilization: Optional[float] = Field( + None, + description="Maximum GPU memory utilization use for the engine. Default to 90%.", + ) + + trust_remote_code: Optional[bool] = Field( + default=False, + description="Whether to trust remote code from Hugging face hub. This is only applicable to models whose code is not supported natively by the transformers library (e.g. deepseek). Default to False.", + ) + + pipeline_parallel_size: Optional[int] = Field( + None, + description="Number of pipeline stages. Default to None.", + ) + + tensor_parallel_size: Optional[int] = Field( + None, + description="Number of tensor parallel replicas. Default to None.", + ) + + quantization: Optional[str] = Field( + None, + description="Method used to quantize the weights. If " + "None, we first check the `quantization_config` " + "attribute in the model config file. If that is " + "None, we assume the model weights are not " + "quantized and use `dtype` to determine the data " + "type of the weights.", + ) + + disable_log_requests: Optional[bool] = Field( + None, + description="Disable logging requests. Default to None.", + ) + + chat_template: Optional[str] = Field( + None, + description="A Jinja template to use for this endpoint. If not provided, will use the chat template from the checkpoint", + ) + + tool_call_parser: Optional[str] = Field( + None, + description="Tool call parser", + ) + + enable_auto_tool_choice: Optional[bool] = Field( + None, + description="Enable auto tool choice", + ) + + load_format: Optional[str] = Field( + None, + description="The format of the model weights to load.\n\n" + '* "auto" will try to load the weights in the safetensors format ' + "and fall back to the pytorch bin format if safetensors format " + "is not available.\n" + '* "pt" will load the weights in the pytorch bin format.\n' + '* "safetensors" will load the weights in the safetensors format.\n' + '* "npcache" will load the weights in pytorch format and store ' + "a numpy cache to speed up the loading.\n" + '* "dummy" will initialize the weights with random values, ' + "which is mainly for profiling.\n" + '* "tensorizer" will load the weights using tensorizer from ' + "CoreWeave. See the Tensorize vLLM Model script in the Examples " + "section for more information.\n" + '* "bitsandbytes" will load the weights using bitsandbytes ' + "quantization.\n", + ) + + config_format: Optional[str] = Field( + None, + description="The config format which shall be loaded. Defaults to 'auto' which defaults to 'hf'.", + ) + + tokenizer_mode: Optional[str] = Field( + None, + description="Tokenizer mode. 'auto' will use the fast tokenizer if" + "available, 'slow' will always use the slow tokenizer, and" + "'mistral' will always use the tokenizer from `mistral_common`.", + ) + + limit_mm_per_prompt: Optional[str] = Field( + None, + description="Maximum number of data instances per modality per prompt. Only applicable for multimodal models.", + ) + + enable_prefix_caching: Optional[bool] = Field( + None, + description="Enables automatic prefix caching.", + ) + + max_num_batched_tokens: Optional[int] = Field( + None, description="Maximum number of batched tokens per iteration" + ) + + +class VLLMEngineAdditionalArgs(BaseModel): + """Additional arguments to configure for vLLM that are not direct inputs to the vLLM engine""" + + max_gpu_memory_utilization: Optional[float] = Field( + None, + description="Maximum GPU memory utilization for the batch inference. Default to 90%. Deprecated in favor of specifying this in VLLMModelConfig", + ) + + attention_backend: Optional[str] = Field( + default=None, + description="Attention backend to use for vLLM. Default to None.", + ) + + +class VLLMEndpointAdditionalArgs(VLLMModelConfig, VLLMEngineAdditionalArgs, BaseModel): + pass + + +class VLLMSamplingParams(BaseModel): + best_of: Optional[int] = Field( + None, + description="""Number of output sequences that are generated from the prompt. + From these `best_of` sequences, the top `n` sequences are returned. + `best_of` must be greater than or equal to `n`. This is treated as + the beam width when `use_beam_search` is True. By default, `best_of` + is set to `n`.""", + ) + top_k: Annotated[ + Optional[int], + Field( + None, + ge=-1, + description="Controls the number of top tokens to consider. -1 means consider all tokens.", + ), + ] + min_p: Optional[float] = Field( + None, + description="""Float that represents the minimum probability for a token to be + considered, relative to the probability of the most likely token. + Must be in [0, 1]. Set to 0 to disable this.""", + ) + use_beam_search: Optional[bool] = Field( + None, + description="""Whether to use beam search for sampling.""", + ) + length_penalty: Optional[float] = Field( + default=None, + description="""Float that penalizes sequences based on their length. + Used in beam search.""", + ) + repetition_penalty: Optional[float] = Field( + default=None, + description="""Float that penalizes new tokens based on whether + they appear in the prompt and the generated text so far. Values > 1 + encourage the model to use new tokens, while values < 1 encourage + the model to repeat tokens.""", + ) + early_stopping: Optional[bool] = Field( + None, + description="""Controls the stopping condition for beam search. It + accepts the following values: `True`, where the generation stops as + soon as there are `best_of` complete candidates; `False`, where an + heuristic is applied and the generation stops when is it very + unlikely to find better candidates; `"never"`, where the beam search + procedure only stops when there cannot be better candidates + (canonical beam search algorithm).""", + ) + stop_token_ids: Optional[List[int]] = Field( + default_factory=list, + description="""List of tokens that stop the generation when they are + generated. The returned output will contain the stop tokens unless + the stop tokens are special tokens.""", + ) + include_stop_str_in_output: Annotated[ + Optional[bool], + Field( + None, + description="""Whether to include the stop strings in + output text. Defaults to False.""", + ), + ] + ignore_eos: Optional[bool] = Field( + None, + description="""Whether to ignore the EOS token and continue generating + tokens after the EOS token is generated.""", + ) + min_tokens: Optional[int] = Field( + None, + description="""Minimum number of tokens to generate per output sequence + before EOS or stop_token_ids can be generated""", + ) + + skip_special_tokens: Optional[bool] = Field( + True, + description="Whether to skip special tokens in the output. Only supported in vllm.", + ) + + spaces_between_special_tokens: Optional[bool] = Field( + True, + description="Whether to add spaces between special tokens in the output. Only supported in vllm.", + ) + + +class VLLMChatCompletionAdditionalParams(VLLMSamplingParams): + chat_template: Optional[str] = Field( + default=None, + description=( + "A Jinja template to use for this conversion. " + "As of transformers v4.44, default chat template is no longer " + "allowed, so you must provide a chat template if the model's tokenizer " + "does not define one and no override template is given" + ), + ) + chat_template_kwargs: Optional[Dict[str, Any]] = Field( + default=None, + description=( + "Additional kwargs to pass to the template renderer. " + "Will be accessible by the chat template." + ), + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ) + + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ) + + +class VLLMCompletionAdditionalParams(VLLMSamplingParams): + add_special_tokens: Optional[bool] = Field( + default=None, + description=( + "If true (the default), special tokens (e.g. BOS) will be added to " "the prompt." + ), + ) + + response_format: Optional[ + ResponseFormatText | ResponseFormatJsonObject | ResponseFormatJsonSchema + ] = Field( + default=None, + description=( + "Similar to chat completion, this parameter specifies the format of " + "output. Only {'type': 'json_object'} or {'type': 'text' } is " + "supported." + ), + ) + + guided_json: Optional[Dict[str, Any]] = Field( + default=None, + description="JSON schema for guided decoding. Only supported in vllm.", + ) + + guided_regex: Optional[str] = Field( + default=None, + description="Regex for guided decoding. Only supported in vllm.", + ) + guided_choice: Optional[List[str]] = Field( + default=None, + description="Choices for guided decoding. Only supported in vllm.", + ) + + guided_grammar: Optional[str] = Field( + default=None, + description="Context-free grammar for guided decoding. Only supported in vllm.", + ) + + guided_decoding_backend: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default guided decoding backend " + "of the server for this specific request. If set, must be either " + "'outlines' / 'lm-format-enforcer'" + ), + ) + + guided_whitespace_pattern: Optional[str] = Field( + default=None, + description=( + "If specified, will override the default whitespace pattern " + "for guided json decoding." + ), + ) diff --git a/server/llm_engine_server/common/dtos/model_bundles.py b/model-engine/model_engine_server/common/dtos/model_bundles.py similarity index 75% rename from server/llm_engine_server/common/dtos/model_bundles.py rename to model-engine/model_engine_server/common/dtos/model_bundles.py index 5d3ece47..99d1e13b 100644 --- a/server/llm_engine_server/common/dtos/model_bundles.py +++ b/model-engine/model_engine_server/common/dtos/model_bundles.py @@ -1,16 +1,17 @@ """ Contains various input and output types relating to Model Bundles for the server. """ + import datetime from enum import Enum from typing import Any, Dict, List, Optional -from llm_engine_server.domain.entities import ( +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field +from model_engine_server.domain.entities import ( ModelBundleEnvironmentParams, ModelBundleFlavors, ModelBundlePackagingType, ) -from pydantic import BaseModel, Field class CreateModelBundleV1Request(BaseModel): @@ -23,9 +24,9 @@ class CreateModelBundleV1Request(BaseModel): requirements: List[str] env_params: ModelBundleEnvironmentParams packaging_type: ModelBundlePackagingType - metadata: Optional[Dict[str, Any]] - app_config: Optional[Dict[str, Any]] - schema_location: Optional[str] + metadata: Optional[Dict[str, Any]] = None + app_config: Optional[Dict[str, Any]] = None + schema_location: Optional[str] = None class CloneModelBundleV1Request(BaseModel): @@ -38,7 +39,7 @@ class CloneModelBundleV1Request(BaseModel): The ID of the ModelBundle to copy from. """ - new_app_config: Optional[Dict[str, Any]] + new_app_config: Optional[Dict[str, Any]] = None """ The app_config of the new ModelBundle. If not specified, then the new ModelBundle will use the same app_config as the original. @@ -50,6 +51,8 @@ class CreateModelBundleV1Response(BaseModel): Response object for creating a Model Bundle. """ + model_config = ConfigDict(protected_namespaces=()) + model_bundle_id: str @@ -58,6 +61,8 @@ class ModelBundleV1Response(BaseModel): Response object for a single Model Bundle. """ + model_config = ConfigDict(from_attributes=True, protected_namespaces=()) + id: str name: str location: str @@ -65,17 +70,10 @@ class ModelBundleV1Response(BaseModel): env_params: ModelBundleEnvironmentParams packaging_type: ModelBundlePackagingType metadata: Dict[str, Any] - app_config: Optional[Dict[str, Any]] + app_config: Optional[Dict[str, Any]] = None created_at: datetime.datetime model_artifact_ids: List[str] - schema_location: Optional[str] - - class Config: - """ - ModelBundleResponse Config class. - """ - - orm_mode = True + schema_location: Optional[str] = None class ListModelBundlesV1Response(BaseModel): @@ -83,6 +81,8 @@ class ListModelBundlesV1Response(BaseModel): Response object for listing Model Bundles. """ + model_config = ConfigDict(protected_namespaces=()) + model_bundles: List[ModelBundleV1Response] @@ -92,7 +92,7 @@ class CreateModelBundleV2Request(BaseModel): """ name: str - metadata: Optional[Dict[str, Any]] + metadata: Optional[Dict[str, Any]] = None schema_location: str flavor: ModelBundleFlavors = Field(..., discriminator="flavor") @@ -107,7 +107,7 @@ class CloneModelBundleV2Request(BaseModel): The ID of the ModelBundle to copy from. """ - new_app_config: Optional[Dict[str, Any]] + new_app_config: Optional[Dict[str, Any]] = None """ The app_config of the new ModelBundle. If not specified, then the new ModelBundle will use the same app_config as the original. @@ -119,6 +119,8 @@ class CreateModelBundleV2Response(BaseModel): Response object for creating a Model Bundle. """ + model_config = ConfigDict(protected_namespaces=()) + model_bundle_id: str @@ -127,27 +129,24 @@ class ModelBundleV2Response(BaseModel): Response object for a single Model Bundle. """ + model_config = ConfigDict(from_attributes=True, protected_namespaces=()) + id: str name: str metadata: Dict[str, Any] created_at: datetime.datetime model_artifact_ids: List[str] - schema_location: Optional[str] + schema_location: Optional[str] = None flavor: ModelBundleFlavors = Field(..., discriminator="flavor") - class Config: - """ - ModelBundleResponse Config class. - """ - - orm_mode = True - class ListModelBundlesV2Response(BaseModel): """ Response object for listing Model Bundles. """ + model_config = ConfigDict(protected_namespaces=()) + model_bundles: List[ModelBundleV2Response] diff --git a/server/llm_engine_server/common/dtos/model_endpoints.py b/model-engine/model_engine_server/common/dtos/model_endpoints.py similarity index 68% rename from server/llm_engine_server/common/dtos/model_endpoints.py rename to model-engine/model_engine_server/common/dtos/model_endpoints.py index 956e8ee1..8e6f929e 100644 --- a/server/llm_engine_server/common/dtos/model_endpoints.py +++ b/model-engine/model_engine_server/common/dtos/model_endpoints.py @@ -10,7 +10,9 @@ from enum import Enum from typing import Any, Dict, List, Optional -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.core import HttpUrlStr +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -21,7 +23,6 @@ ModelEndpointType, StorageSpecificationType, ) -from pydantic import BaseModel, Field, HttpUrl class BrokerType(str, Enum): @@ -32,6 +33,7 @@ class BrokerType(str, Enum): REDIS = "redis" REDIS_24H = "redis_24h" SQS = "sqs" + SERVICEBUS = "servicebus" class BrokerName(str, Enum): @@ -42,6 +44,7 @@ class BrokerName(str, Enum): REDIS = "redis-message-broker-master" SQS = "sqs-message-broker-master" + SERVICEBUS = "servicebus-message-broker-master" class CreateModelEndpointV1Request(BaseModel): @@ -49,21 +52,23 @@ class CreateModelEndpointV1Request(BaseModel): model_bundle_id: str endpoint_type: ModelEndpointType metadata: Dict[str, Any] # TODO: JSON type - post_inference_hooks: Optional[List[str]] + post_inference_hooks: Optional[List[str]] = None cpus: CpuSpecificationType gpus: int = Field(..., ge=0) memory: StorageSpecificationType - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] + gpu_type: Optional[GpuType] = None + storage: StorageSpecificationType + nodes_per_worker: int = Field(gt=0, default=1) + optimize_costs: Optional[bool] = None min_workers: int = Field(..., ge=0) max_workers: int = Field(..., ge=0) per_worker: int = Field(..., gt=0) labels: Dict[str, str] - prewarm: Optional[bool] - high_priority: Optional[bool] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None public_inference: Optional[bool] = Field(default=False) @@ -72,24 +77,25 @@ class CreateModelEndpointV1Response(BaseModel): class UpdateModelEndpointV1Request(BaseModel): - model_bundle_id: Optional[str] - metadata: Optional[Dict[str, Any]] # TODO: JSON type - post_inference_hooks: Optional[List[str]] - cpus: Optional[CpuSpecificationType] + model_bundle_id: Optional[str] = None + metadata: Optional[Dict[str, Any]] = None # TODO: JSON type + post_inference_hooks: Optional[List[str]] = None + cpus: Optional[CpuSpecificationType] = None gpus: Optional[int] = Field(default=None, ge=0) - memory: Optional[StorageSpecificationType] - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] + memory: Optional[StorageSpecificationType] = None + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + optimize_costs: Optional[bool] = None min_workers: Optional[int] = Field(default=None, ge=0) max_workers: Optional[int] = Field(default=None, ge=0) per_worker: Optional[int] = Field(default=None, gt=0) - labels: Optional[Dict[str, str]] - prewarm: Optional[bool] - high_priority: Optional[bool] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] - public_inference: Optional[bool] + labels: Optional[Dict[str, str]] = None + prewarm: Optional[bool] = None + high_priority: Optional[bool] = None + billing_tags: Optional[Dict[str, Any]] = None + default_callback_url: Optional[HttpUrlStr] = None + default_callback_auth: Optional[CallbackAuth] = None + public_inference: Optional[bool] = None class UpdateModelEndpointV1Response(BaseModel): @@ -106,7 +112,7 @@ class GetModelEndpointV1Response(BaseModel): bundle_name: str status: ModelEndpointStatus post_inference_hooks: Optional[List[str]] = Field(default=None) - default_callback_url: Optional[HttpUrl] = Field(default=None) + default_callback_url: Optional[HttpUrlStr] = Field(default=None) default_callback_auth: Optional[CallbackAuth] = Field(default=None) labels: Optional[Dict[str, str]] = Field(default=None) aws_role: Optional[str] = Field(default=None) @@ -139,6 +145,7 @@ class ModelEndpointOrderBy(str, Enum): class GetModelEndpointsSchemaV1Response(BaseModel): + model_config = ConfigDict(protected_namespaces=()) model_endpoints_schema: ModelEndpointsSchema diff --git a/model-engine/model_engine_server/common/dtos/resource_manager.py b/model-engine/model_engine_server/common/dtos/resource_manager.py new file mode 100644 index 00000000..f173a1a8 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/resource_manager.py @@ -0,0 +1,7 @@ +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.common.pydantic_types import BaseModel + + +class CreateOrUpdateResourcesRequest(BaseModel): + build_endpoint_request: BuildEndpointRequest + image: str diff --git a/server/llm_engine_server/common/dtos/tasks.py b/model-engine/model_engine_server/common/dtos/tasks.py similarity index 58% rename from server/llm_engine_server/common/dtos/tasks.py rename to model-engine/model_engine_server/common/dtos/tasks.py index ecd01802..98335277 100644 --- a/server/llm_engine_server/common/dtos/tasks.py +++ b/model-engine/model_engine_server/common/dtos/tasks.py @@ -5,16 +5,16 @@ from enum import Enum from typing import Any, Optional -from llm_engine_server.domain.entities import CallbackAuth -from pydantic import BaseModel +from model_engine_server.common.pydantic_types import BaseModel, Field, RootModel +from model_engine_server.domain.entities import CallbackAuth -class ResponseSchema(BaseModel): - __root__: Any +class ResponseSchema(RootModel): + root: Any -class RequestSchema(BaseModel): - __root__: Any +class RequestSchema(RootModel): + root: Any class TaskStatus(str, Enum): @@ -49,3 +49,11 @@ class EndpointPredictV1Request(BaseModel): callback_url: Optional[str] = None callback_auth: Optional[CallbackAuth] = None return_pickled: bool = False + destination_path: Optional[str] = None + + +class SyncEndpointPredictV1Request(EndpointPredictV1Request): + timeout_seconds: Optional[float] = Field(default=None, gt=0) + num_retries: Optional[int] = Field(default=None, ge=0) + # See live_{sync,streaming}_model_endpoint_inference_gateway to see how timeout_seconds/num_retries interact. + # Also these fields are only relevant for sync endpoints diff --git a/model-engine/model_engine_server/common/dtos/triggers.py b/model-engine/model_engine_server/common/dtos/triggers.py new file mode 100644 index 00000000..a7cf2750 --- /dev/null +++ b/model-engine/model_engine_server/common/dtos/triggers.py @@ -0,0 +1,50 @@ +""" +Contains various input and output types relating to Triggers for the server. +""" + +import datetime +from typing import Any, Dict, List, Optional + +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field + + +class CreateTriggerV1Request(BaseModel): + name: str + cron_schedule: str + bundle_id: str + default_job_config: Optional[Dict[str, Any]] = None + default_job_metadata: Optional[Dict[str, str]] = None + + +class CreateTriggerV1Response(BaseModel): + trigger_id: str + + +class GetTriggerV1Response(BaseModel): + id: str + name: str + owner: str + created_by: str + created_at: datetime.datetime + cron_schedule: str + docker_image_batch_job_bundle_id: str + default_job_config: Optional[Dict[str, Any]] = Field(default=None) + default_job_metadata: Optional[Dict[str, str]] = Field(default=None) + model_config = ConfigDict(from_attributes=True) + + +class ListTriggersV1Response(BaseModel): + triggers: List[GetTriggerV1Response] + + +class UpdateTriggerV1Request(BaseModel): + cron_schedule: Optional[str] = None + suspend: Optional[bool] = None + + +class UpdateTriggerV1Response(BaseModel): + success: bool + + +class DeleteTriggerV1Response(BaseModel): + success: bool diff --git a/model-engine/model_engine_server/common/env_vars.py b/model-engine/model_engine_server/common/env_vars.py new file mode 100644 index 00000000..2a69cbff --- /dev/null +++ b/model-engine/model_engine_server/common/env_vars.py @@ -0,0 +1,80 @@ +""" +A place for defining, setting, and referencing all environment variables used in Launch. +""" + +import os +import sys +from typing import Optional, Sequence + +from model_engine_server.common.constants import PROJECT_ROOT +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger + +__all__: Sequence[str] = ( + "CIRCLECI", + "GIT_TAG", + "LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH", + "LAUNCH_SERVICE_TEMPLATE_FOLDER", + "LOCAL", + "SKIP_AUTH", + "WORKSPACE", + "get_boolean_env_var", +) + +logger = make_logger(logger_name()) + + +def get_boolean_env_var(name: str) -> bool: + """For all env vars that are either on or off. + + An env var is ON iff: + - it is defined + - its value is the literal string 'true' + + If it is present but not set to 'true', it is considered to be OFF. + """ + value = os.environ.get(name) + if value is None: + return False + value = value.strip().lower() + return "true" == value + + +CIRCLECI: bool = get_boolean_env_var("CIRCLECI") + +LOCAL: bool = get_boolean_env_var("LOCAL") +"""Indicates that Launch is running in a local development environment. Also used for local testing. +""" + +SKIP_AUTH: bool = get_boolean_env_var("SKIP_AUTH") or infra_config().identity_service_url is None +"""Indicates that Launch is running in a development environment where authentication is not +required. +""" + +WORKSPACE: str = os.environ.get("WORKSPACE", "~/models") +"""The working directory where hosted_model_inference is installed. +""" + +LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH: str = os.environ.get( + "LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH", + os.path.join( + PROJECT_ROOT, + "model_engine_server/infra/gateways/resources/templates", + "service_template_config_map_circleci.yaml", + ), +) +"""The path to the config map containing the Launch service template. +""" +logger.info(f"{LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH=}") + +LAUNCH_SERVICE_TEMPLATE_FOLDER: Optional[str] = os.environ.get("LAUNCH_SERVICE_TEMPLATE_FOLDER") +"""The path to the folder containing the Launch service template. If set, this overrides +LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH. +""" + +if LOCAL: + logger.warning("LOCAL development & testing mode is ON") + +GIT_TAG: str = os.environ.get("GIT_TAG", "GIT_TAG_NOT_FOUND") +if GIT_TAG == "GIT_TAG_NOT_FOUND" and "pytest" not in sys.modules: + raise ValueError("GIT_TAG environment variable must be set") diff --git a/server/llm_engine_server/common/errors.py b/model-engine/model_engine_server/common/errors.py similarity index 100% rename from server/llm_engine_server/common/errors.py rename to model-engine/model_engine_server/common/errors.py diff --git a/model-engine/model_engine_server/common/io.py b/model-engine/model_engine_server/common/io.py new file mode 100644 index 00000000..c9d9458f --- /dev/null +++ b/model-engine/model_engine_server/common/io.py @@ -0,0 +1,33 @@ +"""Launch Input/Output utils.""" + +import os +from typing import Any + +import boto3 +import smart_open +from model_engine_server.core.config import infra_config + + +def open_wrapper(uri: str, mode: str = "rt", **kwargs): + client: Any + cloud_provider: str + # This follows the 5.1.0 smart_open API + try: + cloud_provider = infra_config().cloud_provider + except Exception: + cloud_provider = "aws" + if cloud_provider == "azure": + from azure.identity import DefaultAzureCredential + from azure.storage.blob import BlobServiceClient + + client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + else: + profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) + session = boto3.Session(profile_name=profile_name) + client = session.client("s3") + + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) diff --git a/model-engine/model_engine_server/common/pydantic_types.py b/model-engine/model_engine_server/common/pydantic_types.py new file mode 100644 index 00000000..6768acae --- /dev/null +++ b/model-engine/model_engine_server/common/pydantic_types.py @@ -0,0 +1,122 @@ +from typing import Any, Type, TypeVar + +from pydantic import AnyHttpUrl as PyAnyHttpUrl +from pydantic import AnyUrl as PyAnyUrl +from pydantic import AnyWebsocketUrl as PyAnyWebsocketUrl +from pydantic import BaseModel as PydanticBaseModel +from pydantic import model_validator # noqa: F401 +from pydantic import ConfigDict, Field # noqa: F401 +from pydantic import FileUrl as PyFileUrl +from pydantic import FtpUrl as PyFtpUrl +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler # noqa: F401 +from pydantic import HttpUrl as PyHttpUrl +from pydantic import RootModel, TypeAdapter, ValidationError # noqa: F401 +from pydantic import WebsocketUrl as PyWebsocketUrl +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import CoreSchema, core_schema + + +class BaseModel(PydanticBaseModel): + """Common pydantic configurations for model engine""" + + model_config = ConfigDict(protected_namespaces=()) + + +# See https://github.com/patrsc/pydantic-string-url +# just copied it over cause it was a single file + +"""Pydantic URL types based on strings.""" + + +T = TypeVar("T", bound=PyAnyUrl) + + +class AnyUrl(str): + """Pydantic's AnyUrl based on str.""" + + _pydantic_type = PyAnyUrl + _example_url = "http://www.example.com/" + + def __init__(self, url: str) -> None: + """Initialize.""" + pydantic_url = validate_url(url, self._pydantic_type) + super().__init__() + self.url = pydantic_url + + @classmethod + def __get_pydantic_core_schema__( + cls, + source_type: Any, # pylint: disable=unused-argument + handler: GetCoreSchemaHandler, + ) -> CoreSchema: + """Get pydantic core schema.""" + return core_schema.no_info_after_validator_function(cls._validate, handler(str)) + + @classmethod + def __get_pydantic_json_schema__( + cls, + schema: CoreSchema, + handler: GetJsonSchemaHandler, + ) -> JsonSchemaValue: + """Get pydantic JSON schema.""" + json_schema = handler(schema) + json_schema = handler.resolve_ref_schema(json_schema) + json_schema["format"] = "uri" + json_schema["minLength"] = 1 + json_schema["maxLength"] = 65536 + json_schema["examples"] = [cls._example_url] + return json_schema + + @classmethod + def _validate(cls, __input_value: str) -> "AnyUrl": + return cls(__input_value) + + +def validate_url(s: str, cls: Type[T]) -> T: + """Validate if string has the format of a proper URL or given Pydantic type.""" + # This uses pydantic's class just for validation. + a = TypeAdapter(cls) + url = a.validate_python(s, strict=True) + return url + + +class AnyHttpUrl(AnyUrl): + """Pydantic's AnyHttpUrl based on str.""" + + _pydantic_type = PyAnyHttpUrl + _example_url = "http://www.example.com/" + + +class HttpUrl(AnyUrl): + """Pydantic's HttpUrl based on str.""" + + _pydantic_type = PyHttpUrl + _example_url = "http://www.example.com/" + + +class AnyWebsocketUrl(AnyUrl): + """Pydantic's AnyWebsocketUrl based on str.""" + + _pydantic_type = PyAnyWebsocketUrl + _example_url = "ws://www.example.com/" + + +class WebsocketUrl(AnyUrl): + """Pydantic's WebsocketUrl based on str.""" + + _pydantic_type = PyWebsocketUrl + _example_url = "ws://www.example.com/" + + +class FileUrl(AnyUrl): + """Pydantic's FileUrl based on str.""" + + _pydantic_type = PyFileUrl + _example_url = "file://www.example.com/" + + +class FtpUrl(AnyUrl): + """Pydantic's FtpUrl based on str.""" + + _pydantic_type = PyFtpUrl + _example_url = "ftp://www.example.com/" diff --git a/server/llm_engine_server/common/resource_limits.py b/model-engine/model_engine_server/common/resource_limits.py similarity index 82% rename from server/llm_engine_server/common/resource_limits.py rename to model-engine/model_engine_server/common/resource_limits.py index cadf4001..1ede52de 100644 --- a/server/llm_engine_server/common/resource_limits.py +++ b/model-engine/model_engine_server/common/resource_limits.py @@ -1,18 +1,18 @@ from typing import Optional, Union, cast -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import ( CpuSpecificationType, GpuType, ModelBundle, StorageSpecificationType, TritonEnhancedRunnableImageFlavor, ) -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.infra.gateways.k8s_resource_parser import ( +from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.infra.gateways.k8s_resource_parser import ( format_bytes, parse_cpu_request, parse_mem_request, @@ -34,30 +34,38 @@ ) # Should we allow multi-gpu instances? This allows the largest single-gpu g5dn instance. # p4d.24xlarge, p4de.24xlarge A100_INSTANCE_LIMITS = dict(cpus=95, memory="1000Gi") -STORAGE_LIMIT = "500G" # TODO: figure out an actual limit. +H100_INSTANCE_LIMITS = dict(cpus=191, memory="2000Gi", storage="1300Gi") +H100_1G_20GB_INSTANCE_LIMITS = dict(cpus=47, memory="500Gi") +H100_3G_40GB_INSTANCE_LIMITS = dict(cpus=95, memory="1000Gi") +STORAGE_LIMIT = "640Gi" # TODO: figure out an actual limit. REQUESTS_BY_GPU_TYPE = { None: CPU_INSTANCE_LIMITS, GpuType.NVIDIA_TESLA_T4: T4_INSTANCE_LIMITS, GpuType.NVIDIA_AMPERE_A10: A10_INSTANCE_LIMITS, GpuType.NVIDIA_AMPERE_A100: A100_INSTANCE_LIMITS, + GpuType.NVIDIA_AMPERE_A100E: A100_INSTANCE_LIMITS, + GpuType.NVIDIA_HOPPER_H100: H100_INSTANCE_LIMITS, + GpuType.NVIDIA_HOPPER_H100_1G_20GB: H100_1G_20GB_INSTANCE_LIMITS, + GpuType.NVIDIA_HOPPER_H100_3G_40GB: H100_3G_40GB_INSTANCE_LIMITS, } -FORWARDER_CPU_USAGE = 0.5 -FORWARDER_MEMORY_USAGE = "1Gi" +FORWARDER_CPU_USAGE = 1 +FORWARDER_MEMORY_USAGE = "2Gi" FORWARDER_STORAGE_USAGE = "1G" +FORWARDER_WORKER_COUNT = 2 -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) def validate_resource_requests( - bundle: Union[ModelBundle, DockerImageBatchJobBundle], + bundle: Optional[Union[ModelBundle, DockerImageBatchJobBundle]], cpus: Optional[CpuSpecificationType], memory: Optional[StorageSpecificationType], storage: Optional[StorageSpecificationType], gpus: Optional[int], gpu_type: Optional[GpuType], ) -> None: - """Validates whether cpu/memory requests are reasonable""" + """Validates whether cpu/memory requests are reasonable. Shouldn't need to validate any nodes_per_worker in the multinode case""" if (gpus is None or gpus == 0) and gpu_type is not None: raise EndpointResourceInvalidRequestException( @@ -142,7 +150,10 @@ def validate_resource_requests( if storage <= 0: raise EndpointResourceInvalidRequestException("Requested storage must be positive") - available_storage_for_user = parse_mem_request(STORAGE_LIMIT) + available_storage_for_user = parse_mem_request( + resource_limits.get("storage", STORAGE_LIMIT) # type: ignore + ) + total_available_storage = available_storage_for_user if isinstance(bundle, ModelBundle): storage += parse_mem_request(FORWARDER_STORAGE_USAGE) @@ -157,7 +168,7 @@ def validate_resource_requests( else: storage += parse_mem_request(bundle.flavor.triton_storage) - if storage > parse_mem_request(STORAGE_LIMIT): + if storage > total_available_storage: raise EndpointResourceInvalidRequestException( f"Requested {storage=} too high. The maximum for {gpu_type=} is {format_bytes(available_storage_for_user)}" ) diff --git a/server/llm_engine_server/common/serialization_utils.py b/model-engine/model_engine_server/common/serialization_utils.py similarity index 100% rename from server/llm_engine_server/common/serialization_utils.py rename to model-engine/model_engine_server/common/serialization_utils.py diff --git a/server/llm_engine_server/common/service_requests.py b/model-engine/model_engine_server/common/service_requests.py similarity index 85% rename from server/llm_engine_server/common/service_requests.py rename to model-engine/model_engine_server/common/service_requests.py index 3e20fbe5..d709bdec 100644 --- a/server/llm_engine_server/common/service_requests.py +++ b/model-engine/model_engine_server/common/service_requests.py @@ -3,8 +3,8 @@ from typing import Any, Dict, Optional import requests -from llm_engine_server.common.errors import HTTP429Exception, UpstreamHTTPSvcError -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.common.errors import HTTP429Exception, UpstreamHTTPSvcError +from model_engine_server.core.loggers import logger_name, make_logger from tenacity import ( RetryError, Retrying, @@ -13,7 +13,7 @@ wait_exponential, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) SYNC_ENDPOINT_RETRIES = 10 # Must be an integer >= 0 SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 @@ -37,7 +37,8 @@ def make_sync_request_with_retries( wait=wait_exponential(multiplier=1, min=1, max=timeout_seconds), ): with attempt: - logger.info(f"Retry number {attempt.retry_state.attempt_number}") + if attempt.retry_state.attempt_number > 1: # pragma: no cover + logger.info(f"Retry number {attempt.retry_state.attempt_number}") resp = requests.post( request_url, json=payload_json, diff --git a/model-engine/model_engine_server/common/settings.py b/model-engine/model_engine_server/common/settings.py new file mode 100644 index 00000000..7438844a --- /dev/null +++ b/model-engine/model_engine_server/common/settings.py @@ -0,0 +1,114 @@ +# This file contains standard settings for ML serve. +# + +import hashlib +from typing import List, Tuple + +from model_engine_server.common.config import hmi_config +from model_engine_server.core.config import infra_config + +DEPLOYMENT_PREFIX = "launch" +LEGACY_DEPLOYMENT_PREFIX = "hmi" +SERVICE_BUILDER_QUEUE_PREFIX = "model-engine" +SERVICE_BUILDER_QUEUE_SUFFIX = "service-builder" +HOSTED_INFERENCE_SERVER_NAME = "hostedinference" +LAUNCH_SERVER_NAME = "launch" +K8S_CACHER_NAME = "launch-k8s-cacher" +PYSPARK_DEFAULT_ENDPOINT_PARAMS = dict( + cpus=3, + memory="12Gi", + gpus=1, + gpu_type="nvidia-tesla-t4", + min_workers=0, + max_workers=50, + per_worker=40, +) # TODO: we could probably determine an appropriate value for max_workers based on the size of the batch +PYSPARK_DEFAULT_MAX_EXECUTORS = 50 +PYSPARK_DEFAULT_PARTITION_SIZE = 500 + +RESTRICTED_ENDPOINT_LABELS = set( + [ + "user_id", + "endpoint_name", + ] +) + +REQUIRED_ENDPOINT_LABELS = set( + [ + "team", + "product", + ] +) + +PRETRAINED_ENDPOINTS_CREATED_BY = ["nucleus-model-zoo", "bloom", "llm", "pretrained"] + + +def generate_deployment_name(user_id, endpoint_name): + return "-".join(_generate_deployment_name_parts(user_id, endpoint_name)) + + +def _generate_queue_name(user_id, endpoint_name): + return ".".join(_generate_deployment_name_parts(user_id, endpoint_name)) + + +def generate_destination(user_id: str, endpoint_name: str, endpoint_type: str) -> str: + if endpoint_type == "async": + return _generate_queue_name(user_id, endpoint_name) + elif endpoint_type in {"sync", "streaming"}: + return generate_deployment_name(user_id, endpoint_name) + else: + raise ValueError(f"Invalid endpoint_type: {endpoint_type}") + + +def _generate_deployment_name_parts(user_id: str, endpoint_name: str) -> List[str]: + user_endpoint_hash = hashlib.md5((user_id + endpoint_name).encode("utf-8")).hexdigest() + return [ + DEPLOYMENT_PREFIX, + user_id[:24], + endpoint_name[:8], + user_endpoint_hash[:8], + ] + + +def generate_batch_job_name(user_id: str, endpoint_name: str): + batch_job_partial_name = "-".join(_generate_deployment_name_parts(user_id, endpoint_name)) + return f"batch-job-{batch_job_partial_name}" + + +def get_sync_endpoint_hostname_and_url(deployment_name: str) -> Tuple[str, str]: + hostname = f"{deployment_name}.{hmi_config.endpoint_namespace}" + return hostname, f"http://{hostname}/predict" + + +def get_sync_endpoint_elb_url(deployment_name: str) -> str: + return f"http://{deployment_name}.{infra_config().dns_host_domain}/predict" + + +def get_service_builder_queue(service_identifier=None): + return ( + f"{SERVICE_BUILDER_QUEUE_PREFIX}-{service_identifier}-{SERVICE_BUILDER_QUEUE_SUFFIX}" + if service_identifier + else f"{SERVICE_BUILDER_QUEUE_PREFIX}-{SERVICE_BUILDER_QUEUE_SUFFIX}" + ) + + +def get_quart_server_name(service_identifier=None): + return ( + f"{HOSTED_INFERENCE_SERVER_NAME}-{service_identifier}" + if service_identifier + else HOSTED_INFERENCE_SERVER_NAME + ) + + +def get_gateway_server_name(service_identifier=None): + return ( + f"{LAUNCH_SERVER_NAME}-{service_identifier}" if service_identifier else LAUNCH_SERVER_NAME + ) + + +def get_service_builder_logs_location(user_id: str, endpoint_name: str): + return f"s3://{infra_config().s3_bucket}/service_builder_logs/{user_id}_{endpoint_name}" + + +def get_k8s_cacher_service_name(service_identifier=None): + return f"{K8S_CACHER_NAME}-{service_identifier}" if service_identifier else K8S_CACHER_NAME diff --git a/model-engine/model_engine_server/common/types/__init__.py b/model-engine/model_engine_server/common/types/__init__.py new file mode 100644 index 00000000..5cdfe557 --- /dev/null +++ b/model-engine/model_engine_server/common/types/__init__.py @@ -0,0 +1,2 @@ +from .endpoint import * # noqa: F403 +from .gen import * # noqa: F403 diff --git a/server/llm_engine_server/common/types.py b/model-engine/model_engine_server/common/types/endpoint.py similarity index 98% rename from server/llm_engine_server/common/types.py rename to model-engine/model_engine_server/common/types/endpoint.py index 22508887..93ccbed6 100644 --- a/server/llm_engine_server/common/types.py +++ b/model-engine/model_engine_server/common/types/endpoint.py @@ -117,3 +117,4 @@ class EndpointBuilderParams(EndpointParams): app_config: Optional[Dict[str, Any]] = None child_fn_info: Optional[Dict[str, Any]] = None post_inference_hooks: Optional[List[str]] = None + billing_tags: Optional[Dict[str, Any]] = None diff --git a/model-engine/model_engine_server/common/types/gen/openai.py b/model-engine/model_engine_server/common/types/gen/openai.py new file mode 100644 index 00000000..c206d98f --- /dev/null +++ b/model-engine/model_engine_server/common/types/gen/openai.py @@ -0,0 +1,5964 @@ +# generated by datamodel-codegen: +# filename: openai-spec.yaml +# timestamp: 2024-10-15T23:20:07+00:00 + +from __future__ import annotations + +from typing import Any, Dict, List, Literal, Optional, Union + +from model_engine_server.common.pydantic_types import ( + AnyUrl, + BaseModel, + ConfigDict, + Field, + RootModel, +) +from typing_extensions import Annotated + + +class Error(BaseModel): + code: Annotated[Optional[str], Field(...)] = None + message: str + param: Annotated[Optional[str], Field(...)] = None + type: str + + +class ErrorResponse(BaseModel): + error: Error + + +class DeleteModelResponse(BaseModel): + id: str + deleted: bool + object: str + + +class Prompt(RootModel[Optional[List[int]]]): + root: Annotated[ + Optional[List[int]], + Field( + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + examples=["[1212, 318, 257, 1332, 13]"], + min_length=1, + ), + ] = "<|endoftext|>" + + +class Prompt1Item(RootModel[List[int]]): + root: Annotated[List[int], Field(min_length=1)] + + +class Prompt1(RootModel[Optional[List[Prompt1Item]]]): + root: Annotated[ + Optional[List[Prompt1Item]], + Field( + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n", + examples=["[[1212, 318, 257, 1332, 13]]"], + min_length=1, + ), + ] = "<|endoftext|>" + + +class Stop(RootModel[Optional[List[str]]]): + root: Annotated[ + Optional[List[str]], + Field( + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n", + max_length=4, + min_length=1, + ), + ] = None + + +class Logprobs(BaseModel): + text_offset: Optional[List[int]] = None + token_logprobs: Optional[List[float]] = None + tokens: Optional[List[str]] = None + top_logprobs: Optional[List[Dict[str, float]]] = None + + +class Choice(BaseModel): + finish_reason: Annotated[ + Optional[Literal["stop", "length", "content_filter"]], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\nor `content_filter` if content was omitted due to a flag from our content filters.\n" + ), + ] + index: int + logprobs: Annotated[Optional[Logprobs], Field(...)] + text: str + + +class ChatCompletionRequestMessageContentPartText(BaseModel): + type: Annotated[Literal["text"], Field(description="The type of the content part.")] + text: Annotated[str, Field(description="The text content.")] + + +class ImageUrl(BaseModel): + url: Annotated[ + AnyUrl, + Field(description="Either a URL of the image or the base64 encoded image data."), + ] + detail: Annotated[ + Literal["auto", "low", "high"], + Field( + description="Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding)." + ), + ] = "auto" + + +class ChatCompletionRequestMessageContentPartImage(BaseModel): + type: Annotated[Literal["image_url"], Field(description="The type of the content part.")] + image_url: ImageUrl + + +class ChatCompletionRequestMessageContentPartRefusal(BaseModel): + type: Annotated[Literal["refusal"], Field(description="The type of the content part.")] + refusal: Annotated[str, Field(description="The refusal message generated by the model.")] + + +class ChatCompletionRequestSystemMessageContentPart( + RootModel[ChatCompletionRequestMessageContentPartText] +): + root: ChatCompletionRequestMessageContentPartText + + +class ChatCompletionRequestUserMessageContentPart( + RootModel[ + Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + ] +): + root: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartImage, + ] + + +class ChatCompletionRequestAssistantMessageContentPart( + RootModel[ + Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartRefusal, + ] + ] +): + root: Union[ + ChatCompletionRequestMessageContentPartText, + ChatCompletionRequestMessageContentPartRefusal, + ] + + +class ChatCompletionRequestToolMessageContentPart( + RootModel[ChatCompletionRequestMessageContentPartText] +): + root: ChatCompletionRequestMessageContentPartText + + +class Content(RootModel[List[ChatCompletionRequestSystemMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestSystemMessageContentPart], + Field( + description="An array of content parts with a defined type. For system messages, only type `text` is supported.", + min_length=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestSystemMessage(BaseModel): + content: Annotated[ + Union[str, Content], Field(description="The contents of the system message.") + ] + role: Annotated[ + Literal["system"], + Field(description="The role of the messages author, in this case `system`."), + ] + name: Annotated[ + Optional[str], + Field( + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." + ), + ] = None + + +class Content1(RootModel[List[ChatCompletionRequestUserMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestUserMessageContentPart], + Field( + description="An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model.", + min_length=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestUserMessage(BaseModel): + content: Annotated[ + Union[str, Content1], Field(description="The contents of the user message.\n") + ] + role: Annotated[ + Literal["user"], + Field(description="The role of the messages author, in this case `user`."), + ] + name: Annotated[ + Optional[str], + Field( + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." + ), + ] = None + + +class Content2(RootModel[Optional[List[ChatCompletionRequestAssistantMessageContentPart]]]): + root: Annotated[ + Optional[List[ChatCompletionRequestAssistantMessageContentPart]], + Field( + description="An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`.", + min_length=1, + title="Array of content parts", + ), + ] = None + + +class FunctionCall(BaseModel): + arguments: Annotated[ + str, + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] + name: Annotated[str, Field(description="The name of the function to call.")] + + +class Content3(RootModel[List[ChatCompletionRequestToolMessageContentPart]]): + root: Annotated[ + List[ChatCompletionRequestToolMessageContentPart], + Field( + description="An array of content parts with a defined type. For tool messages, only type `text` is supported.", + min_length=1, + title="Array of content parts", + ), + ] + + +class ChatCompletionRequestToolMessage(BaseModel): + role: Annotated[ + Literal["tool"], + Field(description="The role of the messages author, in this case `tool`."), + ] + content: Annotated[Union[str, Content3], Field(description="The contents of the tool message.")] + tool_call_id: Annotated[str, Field(description="Tool call that this message is responding to.")] + + +class ChatCompletionRequestFunctionMessage(BaseModel): + role: Annotated[ + Literal["function"], + Field(description="The role of the messages author, in this case `function`."), + ] + content: Annotated[ + Optional[str], Field(description="The contents of the function message.") + ] = None + name: Annotated[str, Field(description="The name of the function to call.")] + + +class FunctionParameters(BaseModel): + pass + model_config = ConfigDict( + extra="allow", + ) + + +class ChatCompletionFunctions(BaseModel): + description: Annotated[ + Optional[str], + Field( + description="A description of what the function does, used by the model to choose when and how to call the function." + ), + ] = None + name: Annotated[ + str, + Field( + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + parameters: Optional[FunctionParameters] = None + + +class ChatCompletionFunctionCallOption(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class FunctionObject(BaseModel): + description: Annotated[ + Optional[str], + Field( + description="A description of what the function does, used by the model to choose when and how to call the function." + ), + ] = None + name: Annotated[ + str, + Field( + description="The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + parameters: Optional[FunctionParameters] = None + strict: Annotated[ + Optional[bool], + Field( + description="Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling)." + ), + ] = None + + +class ResponseFormatText(BaseModel): + type: Annotated[ + Literal["text"], + Field(description="The type of response format being defined: `text`"), + ] + + +class ResponseFormatJsonObject(BaseModel): + type: Annotated[ + Literal["json_object"], + Field(description="The type of response format being defined: `json_object`"), + ] + + +class ResponseFormatJsonSchemaSchema(BaseModel): + pass + model_config = ConfigDict( + extra="allow", + ) + + +class JsonSchema(BaseModel): + description: Annotated[ + Optional[str], + Field( + description="A description of what the response format is for, used by the model to determine how to respond in the format." + ), + ] = None + name: Annotated[ + str, + Field( + description="The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64." + ), + ] + schema_: Annotated[Optional[ResponseFormatJsonSchemaSchema], Field(alias="schema")] = None + strict: Annotated[ + Optional[bool], + Field( + description="Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs)." + ), + ] = False + + +class ResponseFormatJsonSchema(BaseModel): + type: Annotated[ + Literal["json_schema"], + Field(description="The type of response format being defined: `json_schema`"), + ] + json_schema: JsonSchema + + +class Function(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class ChatCompletionNamedToolChoice(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: Function + + +class ParallelToolCalls(RootModel[bool]): + root: Annotated[ + bool, + Field( + description="Whether to enable [parallel function calling](/docs/guides/function-calling/parallel-function-calling) during tool use." + ), + ] + + +class Function1(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + arguments: Annotated[ + str, + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] + + +class ChatCompletionMessageToolCall(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call.")] + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: Annotated[Function1, Field(description="The function that the model called.")] + + +class Function2(BaseModel): + name: Annotated[Optional[str], Field(description="The name of the function to call.")] = None + arguments: Annotated[ + Optional[str], + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] = None + + +class ChatCompletionMessageToolCallChunk(BaseModel): + index: int + id: Annotated[Optional[str], Field(description="The ID of the tool call.")] = None + type: Annotated[ + Optional[Literal["function"]], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] = None + function: Optional[Function2] = None + + +class ChatCompletionRole(RootModel[Literal["system", "user", "assistant", "tool", "function"]]): + root: Annotated[ + Literal["system", "user", "assistant", "tool", "function"], + Field(description="The role of the author of a message"), + ] + + +class ChatCompletionStreamOptions(BaseModel): + include_usage: Annotated[ + Optional[bool], + Field( + description="If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value.\n" + ), + ] = None + + +class FunctionCall2(BaseModel): + arguments: Annotated[ + Optional[str], + Field( + description="The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function." + ), + ] = None + name: Annotated[Optional[str], Field(description="The name of the function to call.")] = None + + +class ChatCompletionStreamResponseDelta(BaseModel): + content: Annotated[Optional[str], Field(description="The contents of the chunk message.")] = ( + None + ) + function_call: Annotated[ + Optional[FunctionCall2], + Field( + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + ), + ] = None + tool_calls: Optional[List[ChatCompletionMessageToolCallChunk]] = None + role: Annotated[ + Optional[Literal["system", "user", "assistant", "tool"]], + Field(description="The role of the author of this message."), + ] = None + refusal: Annotated[ + Optional[str], Field(description="The refusal message generated by the model.") + ] = None + + +class Stop1(RootModel[List[str]]): + root: Annotated[ + List[str], + Field( + description="Up to 4 sequences where the API will stop generating further tokens.\n", + max_length=4, + min_length=1, + ), + ] + + +class TopLogprob(BaseModel): + token: Annotated[str, Field(description="The token.")] + logprob: Annotated[ + float, + Field( + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ] + bytes: Annotated[ + Optional[List[int]], + Field( + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." + ), + ] + + +class ChatCompletionTokenLogprob(BaseModel): + token: Annotated[str, Field(description="The token.")] + logprob: Annotated[ + float, + Field( + description="The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely." + ), + ] + bytes: Annotated[ + Optional[List[int]], + Field( + description="A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token." + ), + ] + top_logprobs: Annotated[ + List[TopLogprob], + Field( + description="List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned." + ), + ] + + +class Logprobs2(BaseModel): + content: Annotated[ + Optional[List[ChatCompletionTokenLogprob]], + Field(description="A list of message content tokens with log probability information."), + ] + refusal: Annotated[ + Optional[List[ChatCompletionTokenLogprob]], + Field(description="A list of message refusal tokens with log probability information."), + ] = None + + +class Choice3(BaseModel): + delta: ChatCompletionStreamResponseDelta + logprobs: Annotated[ + Optional[Logprobs2], + Field(description="Log probability information for the choice."), + ] = None + finish_reason: Annotated[ + Optional[Literal["stop", "length", "tool_calls", "content_filter", "function_call"]], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + + +class Usage(BaseModel): + completion_tokens: Annotated[ + int, Field(description="Number of tokens in the generated completion.") + ] + prompt_tokens: Annotated[int, Field(description="Number of tokens in the prompt.")] + total_tokens: Annotated[ + int, + Field(description="Total number of tokens used in the request (prompt + completion)."), + ] + + +class CreateChatCompletionStreamResponse(BaseModel): + id: Annotated[ + str, + Field( + description="A unique identifier for the chat completion. Each chunk has the same ID." + ), + ] + choices: Annotated[ + List[Choice3], + Field( + description='A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the\nlast chunk if you set `stream_options: {"include_usage": true}`.\n' + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp." + ), + ] + model: Annotated[str, Field(description="The model to generate the completion.")] + service_tier: Annotated[ + Optional[Literal["scale", "default"]], + Field( + description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", + examples=["scale"], + ), + ] = None + system_fingerprint: Annotated[ + Optional[str], + Field( + description="This fingerprint represents the backend configuration that the model runs with.\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" + ), + ] = None + object: Annotated[ + Literal["chat.completion.chunk"], + Field(description="The object type, which is always `chat.completion.chunk`."), + ] + usage: Annotated[ + Optional[Usage], + Field( + description='An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request.\nWhen present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request.\n' + ), + ] = None + + +class CreateChatCompletionImageResponse(BaseModel): + pass + + +class CreateImageRequest(BaseModel): + prompt: Annotated[ + str, + Field( + description="A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`.", + examples=["A cute baby sea otter"], + ), + ] + model: Annotated[ + Optional[Union[Optional[str], Literal["dall-e-2", "dall-e-3"]]], + Field(description="The model to use for image generation.", examples=["dall-e-3"]), + ] = "dall-e-2" + n: Annotated[ + Optional[int], + Field( + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + examples=[1], + ge=1, + le=10, + ), + ] = 1 + quality: Annotated[ + Literal["standard", "hd"], + Field( + description="The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`.", + examples=["standard"], + ), + ] = "standard" + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] = "url" + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"]], + Field( + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models.", + examples=["1024x1024"], + ), + ] = "1024x1024" + style: Annotated[ + Optional[Literal["vivid", "natural"]], + Field( + description="The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`.", + examples=["vivid"], + ), + ] = "vivid" + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] = None + + +class Image(BaseModel): + b64_json: Annotated[ + Optional[str], + Field( + description="The base64-encoded JSON of the generated image, if `response_format` is `b64_json`." + ), + ] = None + url: Annotated[ + Optional[str], + Field( + description="The URL of the generated image, if `response_format` is `url` (default)." + ), + ] = None + revised_prompt: Annotated[ + Optional[str], + Field( + description="The prompt that was used to generate the image, if there was any revision to the prompt." + ), + ] = None + + +class CreateImageEditRequest(BaseModel): + image: Annotated[ + bytes, + Field( + description="The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask." + ), + ] + prompt: Annotated[ + str, + Field( + description="A text description of the desired image(s). The maximum length is 1000 characters.", + examples=["A cute baby sea otter wearing a beret"], + ), + ] + mask: Annotated[ + Optional[bytes], + Field( + description="An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`." + ), + ] = None + model: Annotated[ + Optional[Union[Optional[str], Literal["dall-e-2"]]], + Field( + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + examples=["dall-e-2"], + ), + ] = "dall-e-2" + n: Annotated[ + Optional[int], + Field( + description="The number of images to generate. Must be between 1 and 10.", + examples=[1], + ge=1, + le=10, + ), + ] = 1 + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024"]], + Field( + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + examples=["1024x1024"], + ), + ] = "1024x1024" + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] = "url" + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] = None + + +class CreateImageVariationRequest(BaseModel): + image: Annotated[ + bytes, + Field( + description="The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square." + ), + ] + model: Annotated[ + Optional[Union[Optional[str], Literal["dall-e-2"]]], + Field( + description="The model to use for image generation. Only `dall-e-2` is supported at this time.", + examples=["dall-e-2"], + ), + ] = "dall-e-2" + n: Annotated[ + Optional[int], + Field( + description="The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported.", + examples=[1], + ge=1, + le=10, + ), + ] = 1 + response_format: Annotated[ + Optional[Literal["url", "b64_json"]], + Field( + description="The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated.", + examples=["url"], + ), + ] = "url" + size: Annotated[ + Optional[Literal["256x256", "512x512", "1024x1024"]], + Field( + description="The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`.", + examples=["1024x1024"], + ), + ] = "1024x1024" + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] = None + + +class CreateModerationRequest(BaseModel): + input: Annotated[Union[str, List[str]], Field(description="The input text to classify")] + model: Annotated[ + Union[str, Literal["text-moderation-latest", "text-moderation-stable"]], + Field( + description="Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`.\n\nThe default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`.\n", + examples=["text-moderation-stable"], + ), + ] = "text-moderation-latest" + + +class Categories(BaseModel): + hate: Annotated[ + bool, + Field( + description="Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harassment." + ), + ] + hate_threatening: Annotated[ + bool, + Field( + alias="hate/threatening", + description="Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste.", + ), + ] + harassment: Annotated[ + bool, + Field( + description="Content that expresses, incites, or promotes harassing language towards any target." + ), + ] + harassment_threatening: Annotated[ + bool, + Field( + alias="harassment/threatening", + description="Harassment content that also includes violence or serious harm towards any target.", + ), + ] + self_harm: Annotated[ + bool, + Field( + alias="self-harm", + description="Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders.", + ), + ] + self_harm_intent: Annotated[ + bool, + Field( + alias="self-harm/intent", + description="Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders.", + ), + ] + self_harm_instructions: Annotated[ + bool, + Field( + alias="self-harm/instructions", + description="Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts.", + ), + ] + sexual: Annotated[ + bool, + Field( + description="Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness)." + ), + ] + sexual_minors: Annotated[ + bool, + Field( + alias="sexual/minors", + description="Sexual content that includes an individual who is under 18 years old.", + ), + ] + violence: Annotated[ + bool, + Field(description="Content that depicts death, violence, or physical injury."), + ] + violence_graphic: Annotated[ + bool, + Field( + alias="violence/graphic", + description="Content that depicts death, violence, or physical injury in graphic detail.", + ), + ] + + +class CategoryScores(BaseModel): + hate: Annotated[float, Field(description="The score for the category 'hate'.")] + hate_threatening: Annotated[ + float, + Field( + alias="hate/threatening", + description="The score for the category 'hate/threatening'.", + ), + ] + harassment: Annotated[float, Field(description="The score for the category 'harassment'.")] + harassment_threatening: Annotated[ + float, + Field( + alias="harassment/threatening", + description="The score for the category 'harassment/threatening'.", + ), + ] + self_harm: Annotated[ + float, + Field(alias="self-harm", description="The score for the category 'self-harm'."), + ] + self_harm_intent: Annotated[ + float, + Field( + alias="self-harm/intent", + description="The score for the category 'self-harm/intent'.", + ), + ] + self_harm_instructions: Annotated[ + float, + Field( + alias="self-harm/instructions", + description="The score for the category 'self-harm/instructions'.", + ), + ] + sexual: Annotated[float, Field(description="The score for the category 'sexual'.")] + sexual_minors: Annotated[ + float, + Field( + alias="sexual/minors", + description="The score for the category 'sexual/minors'.", + ), + ] + violence: Annotated[float, Field(description="The score for the category 'violence'.")] + violence_graphic: Annotated[ + float, + Field( + alias="violence/graphic", + description="The score for the category 'violence/graphic'.", + ), + ] + + +class Result(BaseModel): + flagged: Annotated[bool, Field(description="Whether any of the below categories are flagged.")] + categories: Annotated[ + Categories, + Field(description="A list of the categories, and whether they are flagged or not."), + ] + category_scores: Annotated[ + CategoryScores, + Field( + description="A list of the categories along with their scores as predicted by model." + ), + ] + + +class CreateModerationResponse(BaseModel): + id: Annotated[str, Field(description="The unique identifier for the moderation request.")] + model: Annotated[str, Field(description="The model used to generate the moderation results.")] + results: Annotated[List[Result], Field(description="A list of moderation objects.")] + + +class CreateFileRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[bytes, Field(description="The File object (not file name) to be uploaded.\n")] + purpose: Annotated[ + Literal["assistants", "batch", "fine-tune", "vision"], + Field( + description='The intended purpose of the uploaded file.\n\nUse "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning).\n' + ), + ] + + +class DeleteFileResponse(BaseModel): + id: str + object: Literal["file"] + deleted: bool + + +class CreateUploadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + filename: Annotated[str, Field(description="The name of the file to upload.\n")] + purpose: Annotated[ + Literal["assistants", "batch", "fine-tune", "vision"], + Field( + description="The intended purpose of the uploaded file.\n\nSee the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose).\n" + ), + ] + bytes: Annotated[int, Field(description="The number of bytes in the file you are uploading.\n")] + mime_type: Annotated[ + str, + Field( + description="The MIME type of the file.\n\nThis must fall within the supported MIME types for your file purpose. See the supported MIME types for assistants and vision.\n" + ), + ] + + +class AddUploadPartRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + data: Annotated[bytes, Field(description="The chunk of bytes for this Part.\n")] + + +class CompleteUploadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + part_ids: Annotated[List[str], Field(description="The ordered list of Part IDs.\n")] + md5: Annotated[ + Optional[str], + Field( + description="The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect.\n" + ), + ] = None + + +class CancelUploadRequest(BaseModel): + pass + model_config = ConfigDict( + extra="forbid", + ) + + +class BatchSize(RootModel[int]): + root: Annotated[ + int, + Field( + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n", + ge=1, + le=256, + ), + ] + + +class LearningRateMultiplier(RootModel[float]): + root: Annotated[ + float, + Field( + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n", + gt=0.0, + ), + ] + + +class NEpochs(RootModel[int]): + root: Annotated[ + int, + Field( + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n", + ge=1, + le=50, + ), + ] + + +class Hyperparameters(BaseModel): + batch_size: Annotated[ + Union[Literal["auto"], BatchSize], + Field( + description="Number of examples in each batch. A larger batch size means that model parameters\nare updated less frequently, but with lower variance.\n" + ), + ] = "auto" + learning_rate_multiplier: Annotated[ + Union[Literal["auto"], LearningRateMultiplier], + Field( + description="Scaling factor for the learning rate. A smaller learning rate may be useful to avoid\noverfitting.\n" + ), + ] = "auto" + n_epochs: Annotated[ + Union[Literal["auto"], NEpochs], + Field( + description="The number of epochs to train the model for. An epoch refers to one full cycle\nthrough the training dataset.\n" + ), + ] = "auto" + + +class Wandb(BaseModel): + project: Annotated[ + str, + Field( + description="The name of the project that the new run will be created under.\n", + examples=["my-wandb-project"], + ), + ] + name: Annotated[ + Optional[str], + Field( + description="A display name to set for the run. If not set, we will use the Job ID as the name.\n" + ), + ] = None + entity: Annotated[ + Optional[str], + Field( + description="The entity to use for the run. This allows you to set the team or username of the WandB user that you would\nlike associated with the run. If not set, the default entity for the registered WandB API key is used.\n" + ), + ] = None + tags: Annotated[ + Optional[List[str]], + Field( + description='A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some\ndefault tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}".\n' + ), + ] = None + + +class Integration(BaseModel): + type: Annotated[ + Literal["wandb"], + Field( + description='The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported.\n' + ), + ] + wandb: Annotated[ + Wandb, + Field( + description="The settings for your integration with Weights and Biases. This payload specifies the project that\nmetrics will be sent to. Optionally, you can set an explicit display name for your run, add tags\nto your run, and set a default entity (team, username, etc) to be associated with your run.\n" + ), + ] + + +class CreateFineTuningJobRequest(BaseModel): + model: Annotated[ + Union[str, Literal["babbage-002", "davinci-002", "gpt-3.5-turbo", "gpt-4o-mini"]], + Field( + description="The name of the model to fine-tune. You can select one of the\n[supported models](/docs/guides/fine-tuning/which-models-can-be-fine-tuned).\n", + examples=["gpt-4o-mini"], + ), + ] + training_file: Annotated[ + str, + Field( + description="The ID of an uploaded file that contains training data.\n\nSee [upload file](/docs/api-reference/files/create) for how to upload a file.\n\nYour dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`.\n\nThe contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + examples=["file-abc123"], + ), + ] + hyperparameters: Annotated[ + Optional[Hyperparameters], + Field(description="The hyperparameters used for the fine-tuning job."), + ] = None + suffix: Annotated[ + Optional[str], + Field( + description='A string of up to 18 characters that will be added to your fine-tuned model name.\n\nFor example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`.\n', + max_length=40, + min_length=1, + ), + ] = None + validation_file: Annotated[ + Optional[str], + Field( + description="The ID of an uploaded file that contains validation data.\n\nIf you provide this file, the data is used to generate validation\nmetrics periodically during fine-tuning. These metrics can be viewed in\nthe fine-tuning results file.\nThe same data should not be present in both train and validation files.\n\nYour dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`.\n\nSee the [fine-tuning guide](/docs/guides/fine-tuning) for more details.\n", + examples=["file-abc123"], + ), + ] = None + integrations: Annotated[ + Optional[List[Integration]], + Field(description="A list of integrations to enable for your fine-tuning job."), + ] = None + seed: Annotated[ + Optional[int], + Field( + description="The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases.\nIf a seed is not specified, one will be generated for you.\n", + examples=[42], + ge=0, + le=2147483647, + ), + ] = None + + +class Input(RootModel[List[str]]): + root: Annotated[ + List[str], + Field( + description="The array of strings that will be turned into an embedding.", + examples=["The quick brown fox jumped over the lazy dog"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class Input1(RootModel[List[int]]): + root: Annotated[ + List[int], + Field( + description="The array of integers that will be turned into an embedding.", + examples=["[1212, 318, 257, 1332, 13]"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class Input2Item(RootModel[List[int]]): + root: Annotated[List[int], Field(min_length=1)] + + +class Input2(RootModel[List[Input2Item]]): + root: Annotated[ + List[Input2Item], + Field( + description="The array of arrays containing integers that will be turned into an embedding.", + examples=["[[1212, 318, 257, 1332, 13]]"], + max_length=2048, + min_length=1, + title="array", + ), + ] + + +class CreateEmbeddingRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + input: Annotated[ + Union[str, Input, Input1, Input2], + Field( + description="Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + examples=["The quick brown fox jumped over the lazy dog"], + ), + ] + model: Annotated[ + Union[ + str, + Literal[ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ], + ], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + examples=["text-embedding-3-small"], + ), + ] + encoding_format: Annotated[ + Literal["float", "base64"], + Field( + description="The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/).", + examples=["float"], + ), + ] = "float" + dimensions: Annotated[ + Optional[int], + Field( + description="The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models.\n", + ge=1, + ), + ] = None + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] = None + + +class Usage1(BaseModel): + prompt_tokens: Annotated[int, Field(description="The number of tokens used by the prompt.")] + total_tokens: Annotated[ + int, Field(description="The total number of tokens used by the request.") + ] + + +class CreateTranscriptionRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[ + bytes, + Field( + description="The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n" + ), + ] + model: Annotated[ + Union[str, Literal["whisper-1"]], + Field( + description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", + examples=["whisper-1"], + ), + ] + language: Annotated[ + Optional[str], + Field( + description="The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency.\n" + ), + ] = None + prompt: Annotated[ + Optional[str], + Field( + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language.\n" + ), + ] = None + response_format: Annotated[ + Literal["json", "text", "srt", "verbose_json", "vtt"], + Field( + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n" + ), + ] = "json" + temperature: Annotated[ + float, + Field( + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n" + ), + ] = 0 + timestamp_granularities__: Annotated[ + List[Literal["word", "segment"]], + Field( + alias="timestamp_granularities[]", + description="The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency.\n", + ), + ] = ["segment"] + + +class CreateTranscriptionResponseJson(BaseModel): + text: Annotated[str, Field(description="The transcribed text.")] + + +class TranscriptionSegment(BaseModel): + id: Annotated[int, Field(description="Unique identifier of the segment.")] + seek: Annotated[int, Field(description="Seek offset of the segment.")] + start: Annotated[float, Field(description="Start time of the segment in seconds.")] + end: Annotated[float, Field(description="End time of the segment in seconds.")] + text: Annotated[str, Field(description="Text content of the segment.")] + tokens: Annotated[List[int], Field(description="Array of token IDs for the text content.")] + temperature: Annotated[ + float, + Field(description="Temperature parameter used for generating the segment."), + ] + avg_logprob: Annotated[ + float, + Field( + description="Average logprob of the segment. If the value is lower than -1, consider the logprobs failed." + ), + ] + compression_ratio: Annotated[ + float, + Field( + description="Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed." + ), + ] + no_speech_prob: Annotated[ + float, + Field( + description="Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent." + ), + ] + + +class TranscriptionWord(BaseModel): + word: Annotated[str, Field(description="The text content of the word.")] + start: Annotated[float, Field(description="Start time of the word in seconds.")] + end: Annotated[float, Field(description="End time of the word in seconds.")] + + +class CreateTranscriptionResponseVerboseJson(BaseModel): + language: Annotated[str, Field(description="The language of the input audio.")] + duration: Annotated[str, Field(description="The duration of the input audio.")] + text: Annotated[str, Field(description="The transcribed text.")] + words: Annotated[ + Optional[List[TranscriptionWord]], + Field(description="Extracted words and their corresponding timestamps."), + ] = None + segments: Annotated[ + Optional[List[TranscriptionSegment]], + Field(description="Segments of the transcribed text and their corresponding details."), + ] = None + + +class CreateTranslationRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file: Annotated[ + bytes, + Field( + description="The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm.\n" + ), + ] + model: Annotated[ + Union[str, Literal["whisper-1"]], + Field( + description="ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available.\n", + examples=["whisper-1"], + ), + ] + prompt: Annotated[ + Optional[str], + Field( + description="An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English.\n" + ), + ] = None + response_format: Annotated[ + str, + Field( + description="The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`.\n" + ), + ] = "json" + temperature: Annotated[ + float, + Field( + description="The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit.\n" + ), + ] = 0 + + +class CreateTranslationResponseJson(BaseModel): + text: str + + +class CreateTranslationResponseVerboseJson(BaseModel): + language: Annotated[ + str, + Field(description="The language of the output translation (always `english`)."), + ] + duration: Annotated[str, Field(description="The duration of the input audio.")] + text: Annotated[str, Field(description="The translated text.")] + segments: Annotated[ + Optional[List[TranscriptionSegment]], + Field(description="Segments of the translated text and their corresponding details."), + ] = None + + +class CreateSpeechRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Union[str, Literal["tts-1", "tts-1-hd"]], + Field( + description="One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd`\n" + ), + ] + input: Annotated[ + str, + Field( + description="The text to generate audio for. The maximum length is 4096 characters.", + max_length=4096, + ), + ] + voice: Annotated[ + Literal["alloy", "echo", "fable", "onyx", "nova", "shimmer"], + Field( + description="The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options)." + ), + ] + response_format: Annotated[ + Literal["mp3", "opus", "aac", "flac", "wav", "pcm"], + Field( + description="The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`." + ), + ] = "mp3" + speed: Annotated[ + float, + Field( + description="The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default.", + ge=0.25, + le=4.0, + ), + ] = 1.0 + + +class Model(BaseModel): + id: Annotated[ + str, + Field(description="The model identifier, which can be referenced in the API endpoints."), + ] + created: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) when the model was created."), + ] + object: Annotated[ + Literal["model"], Field(description='The object type, which is always "model".') + ] + owned_by: Annotated[str, Field(description="The organization that owns the model.")] + + +class OpenAIFile(BaseModel): + id: Annotated[ + str, + Field(description="The file identifier, which can be referenced in the API endpoints."), + ] + bytes: Annotated[int, Field(description="The size of the file, in bytes.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the file was created."), + ] + filename: Annotated[str, Field(description="The name of the file.")] + object: Annotated[ + Literal["file"], Field(description="The object type, which is always `file`.") + ] + purpose: Annotated[ + Literal[ + "assistants", + "assistants_output", + "batch", + "batch_output", + "fine-tune", + "fine-tune-results", + "vision", + ], + Field( + description="The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`." + ), + ] + status: Annotated[ + Literal["uploaded", "processed", "error"], + Field( + description="Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`." + ), + ] + status_details: Annotated[ + Optional[str], + Field( + description="Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`." + ), + ] = None + + +class Upload(BaseModel): + id: Annotated[ + str, + Field( + description="The Upload unique identifier, which can be referenced in API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Upload was created."), + ] + filename: Annotated[str, Field(description="The name of the file to be uploaded.")] + bytes: Annotated[int, Field(description="The intended number of bytes to be uploaded.")] + purpose: Annotated[ + str, + Field( + description="The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values." + ), + ] + status: Annotated[ + Literal["pending", "completed", "cancelled", "expired"], + Field(description="The status of the Upload."), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Upload was created."), + ] + object: Annotated[ + Optional[Literal["upload"]], + Field(description='The object type, which is always "upload".'), + ] = None + file: Annotated[ + Optional[OpenAIFile], + Field(description="The ready File object after the Upload is completed."), + ] = None + + +class UploadPart(BaseModel): + id: Annotated[ + str, + Field( + description="The upload Part unique identifier, which can be referenced in API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the Part was created."), + ] + upload_id: Annotated[ + str, + Field(description="The ID of the Upload object that this Part was added to."), + ] + object: Annotated[ + Literal["upload.part"], + Field(description="The object type, which is always `upload.part`."), + ] + + +class Embedding(BaseModel): + index: Annotated[ + int, Field(description="The index of the embedding in the list of embeddings.") + ] + embedding: Annotated[ + List[float], + Field( + description="The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings).\n" + ), + ] + object: Annotated[ + Literal["embedding"], + Field(description='The object type, which is always "embedding".'), + ] + + +class Error1(BaseModel): + code: Annotated[str, Field(description="A machine-readable error code.")] + message: Annotated[str, Field(description="A human-readable error message.")] + param: Annotated[ + Optional[str], + Field( + description="The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific." + ), + ] = None + + +class NEpochs1(RootModel[int]): + root: Annotated[ + int, + Field( + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.', + ge=1, + le=50, + ), + ] + + +class Hyperparameters1(BaseModel): + n_epochs: Annotated[ + Union[Literal["auto"], NEpochs1], + Field( + description='The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset.\n"auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs.' + ), + ] + + +class FineTuningIntegration(BaseModel): + type: Annotated[ + Literal["wandb"], + Field(description="The type of the integration being enabled for the fine-tuning job"), + ] + wandb: Annotated[ + Wandb, + Field( + description="The settings for your integration with Weights and Biases. This payload specifies the project that\nmetrics will be sent to. Optionally, you can set an explicit display name for your run, add tags\nto your run, and set a default entity (team, username, etc) to be associated with your run.\n" + ), + ] + + +class FineTuningJobEvent(BaseModel): + id: str + created_at: int + level: Literal["info", "warn", "error"] + message: str + object: Literal["fine_tuning.job.event"] + + +class Metrics(BaseModel): + step: Optional[float] = None + train_loss: Optional[float] = None + train_mean_token_accuracy: Optional[float] = None + valid_loss: Optional[float] = None + valid_mean_token_accuracy: Optional[float] = None + full_valid_loss: Optional[float] = None + full_valid_mean_token_accuracy: Optional[float] = None + + +class FineTuningJobCheckpoint(BaseModel): + id: Annotated[ + str, + Field( + description="The checkpoint identifier, which can be referenced in the API endpoints." + ), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the checkpoint was created."), + ] + fine_tuned_model_checkpoint: Annotated[ + str, + Field(description="The name of the fine-tuned checkpoint model that is created."), + ] + step_number: Annotated[ + int, Field(description="The step number that the checkpoint was created at.") + ] + metrics: Annotated[ + Metrics, + Field(description="Metrics at the step number during the fine-tuning job."), + ] + fine_tuning_job_id: Annotated[ + str, + Field(description="The name of the fine-tuning job that this checkpoint was created from."), + ] + object: Annotated[ + Literal["fine_tuning.job.checkpoint"], + Field(description='The object type, which is always "fine_tuning.job.checkpoint".'), + ] + + +class FinetuneCompletionRequestInput(BaseModel): + prompt: Annotated[ + Optional[str], Field(description="The input prompt for this training example.") + ] = None + completion: Annotated[ + Optional[str], + Field(description="The desired completion for this training example."), + ] = None + + +class CompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, Field(description="Number of tokens in the generated completion.") + ] + prompt_tokens: Annotated[int, Field(description="Number of tokens in the prompt.")] + total_tokens: Annotated[ + int, + Field(description="Total number of tokens used in the request (prompt + completion)."), + ] + + +class RunCompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, + Field(description="Number of completion tokens used over the course of the run."), + ] + prompt_tokens: Annotated[ + int, + Field(description="Number of prompt tokens used over the course of the run."), + ] + total_tokens: Annotated[ + int, Field(description="Total number of tokens used (prompt + completion).") + ] + + +class RunStepCompletionUsage(BaseModel): + completion_tokens: Annotated[ + int, + Field(description="Number of completion tokens used over the course of the run step."), + ] + prompt_tokens: Annotated[ + int, + Field(description="Number of prompt tokens used over the course of the run step."), + ] + total_tokens: Annotated[ + int, Field(description="Total number of tokens used (prompt + completion).") + ] + + +class AssistantsApiResponseFormatOption( + RootModel[ + Union[ + Literal["auto"], + ResponseFormatText, + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ] + ] +): + root: Annotated[ + Union[ + Literal["auto"], + ResponseFormatText, + ResponseFormatJsonObject, + ResponseFormatJsonSchema, + ], + Field( + description='Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' + ), + ] + + +class CodeInterpreter(BaseModel): + file_ids: Annotated[ + List[str], + Field( + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] = [] + + +class FileSearch(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] = None + + +class ToolResources(BaseModel): + code_interpreter: Optional[CodeInterpreter] = None + file_search: Optional[FileSearch] = None + + +class CodeInterpreter1(BaseModel): + file_ids: Annotated[ + List[str], + Field( + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] = [] + + +class ChunkingStrategy(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class Static(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + max_chunk_size_tokens: Annotated[ + int, + Field( + description="The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`.", + ge=100, + le=4096, + ), + ] + chunk_overlap_tokens: Annotated[ + int, + Field( + description="The number of tokens that overlap between chunks. The default value is `400`.\n\nNote that the overlap must not exceed half of `max_chunk_size_tokens`.\n" + ), + ] + + +class ChunkingStrategy1(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] = None + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy, ChunkingStrategy1]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class FileSearch1(BaseModel): + vector_store_ids: Annotated[ + List[str], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + Optional[List[VectorStore]], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] = None + + +class ChunkingStrategy2(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy3(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore1(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] = None + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy2, ChunkingStrategy3]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class FileSearch2(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] = None + vector_stores: Annotated[ + List[VectorStore1], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] + + +class ToolResources1(BaseModel): + code_interpreter: Optional[CodeInterpreter1] = None + file_search: Optional[Union[FileSearch1, FileSearch2]] = None + + +class CodeInterpreter2(BaseModel): + file_ids: Annotated[ + List[str], + Field( + description="Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] = [] + + +class FileSearch3(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] = None + + +class ToolResources2(BaseModel): + code_interpreter: Optional[CodeInterpreter2] = None + file_search: Optional[FileSearch3] = None + + +class DeleteAssistantResponse(BaseModel): + id: str + deleted: bool + object: Literal["assistant.deleted"] + + +class AssistantToolsCode(BaseModel): + type: Annotated[ + Literal["code_interpreter"], + Field(description="The type of tool being defined: `code_interpreter`"), + ] + + +class FileSearch4(BaseModel): + max_num_results: Annotated[ + Optional[int], + Field( + description="The maximum number of results the file search tool should output. The default is 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between 1 and 50 inclusive.\n\nNote that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information.\n", + ge=1, + le=50, + ), + ] = None + + +class AssistantToolsFileSearch(BaseModel): + type: Annotated[ + Literal["file_search"], + Field(description="The type of tool being defined: `file_search`"), + ] + file_search: Annotated[ + Optional[FileSearch4], Field(description="Overrides for the file search tool.") + ] = None + + +class AssistantToolsFileSearchTypeOnly(BaseModel): + type: Annotated[ + Literal["file_search"], + Field(description="The type of tool being defined: `file_search`"), + ] + + +class AssistantToolsFunction(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of tool being defined: `function`"), + ] + function: FunctionObject + + +class TruncationObject(BaseModel): + type: Annotated[ + Literal["auto", "last_messages"], + Field( + description="The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`." + ), + ] + last_messages: Annotated[ + Optional[int], + Field( + description="The number of most recent messages from the thread when constructing the context for the run.", + ge=1, + ), + ] = None + + +class Function3(BaseModel): + name: Annotated[str, Field(description="The name of the function to call.")] + + +class AssistantsNamedToolChoice(BaseModel): + type: Annotated[ + Literal["function", "code_interpreter", "file_search"], + Field( + description="The type of the tool. If type is `function`, the function name must be set" + ), + ] + function: Optional[Function3] = None + + +class LastError(BaseModel): + code: Annotated[ + Literal["server_error", "rate_limit_exceeded", "invalid_prompt"], + Field(description="One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class IncompleteDetails(BaseModel): + reason: Annotated[ + Optional[Literal["max_completion_tokens", "max_prompt_tokens"]], + Field( + description="The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run." + ), + ] = None + + +class ModifyRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class ToolOutput(BaseModel): + tool_call_id: Annotated[ + Optional[str], + Field( + description="The ID of the tool call in the `required_action` object within the run object the output is being submitted for." + ), + ] = None + output: Annotated[ + Optional[str], + Field(description="The output of the tool call to be submitted to continue the run."), + ] = None + + +class SubmitToolOutputsRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + tool_outputs: Annotated[ + List[ToolOutput], + Field(description="A list of tools for which the outputs are being submitted."), + ] + stream: Annotated[ + Optional[bool], + Field( + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" + ), + ] = None + + +class Function4(BaseModel): + name: Annotated[str, Field(description="The name of the function.")] + arguments: Annotated[ + str, + Field(description="The arguments that the model expects you to pass to the function."), + ] + + +class RunToolCallObject(BaseModel): + id: Annotated[ + str, + Field( + description="The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint." + ), + ] + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call the output is required for. For now, this is always `function`." + ), + ] + function: Annotated[Function4, Field(description="The function definition.")] + + +class CodeInterpreter3(BaseModel): + file_ids: Annotated[ + List[str], + Field( + description="A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool.\n", + max_length=20, + ), + ] = [] + + +class FileSearch5(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant.\n", + max_length=1, + ), + ] = None + + +class ToolResources3(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch5] = None + + +class FileSearch6(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] = None + + +class ToolResources4(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch6] = None + + +class ThreadObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread"], + Field(description="The object type, which is always `thread`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the thread was created."), + ] + tool_resources: Annotated[ + Optional[ToolResources4], + Field( + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class ChunkingStrategy4(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy5(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore2(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] = None + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy4, ChunkingStrategy5]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class FileSearch7(BaseModel): + vector_store_ids: Annotated[ + List[str], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + vector_stores: Annotated[ + Optional[List[VectorStore2]], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] = None + + +class ChunkingStrategy6(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class ChunkingStrategy7(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: Static + + +class VectorStore3(BaseModel): + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store.\n", + max_length=10000, + ), + ] = None + chunking_strategy: Annotated[ + Optional[Union[ChunkingStrategy6, ChunkingStrategy7]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class FileSearch8(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] = None + vector_stores: Annotated[ + List[VectorStore3], + Field( + description="A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] + + +class ToolResources5(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[Union[FileSearch7, FileSearch8]] = None + + +class FileSearch9(BaseModel): + vector_store_ids: Annotated[ + Optional[List[str]], + Field( + description="The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread.\n", + max_length=1, + ), + ] = None + + +class ToolResources6(BaseModel): + code_interpreter: Optional[CodeInterpreter3] = None + file_search: Optional[FileSearch9] = None + + +class ModifyThreadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + tool_resources: Annotated[ + Optional[ToolResources6], + Field( + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class DeleteThreadResponse(BaseModel): + id: str + deleted: bool + object: Literal["thread.deleted"] + + +class ListThreadsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[ThreadObject] + first_id: Annotated[str, Field(examples=["asst_abc123"])] + last_id: Annotated[str, Field(examples=["asst_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class IncompleteDetails1(BaseModel): + reason: Annotated[ + Literal["content_filter", "max_tokens", "run_cancelled", "run_expired", "run_failed"], + Field(description="The reason the message is incomplete."), + ] + + +class Attachment(BaseModel): + file_id: Annotated[ + Optional[str], Field(description="The ID of the file to attach to the message.") + ] = None + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearchTypeOnly]]], + Field(description="The tools to add this file to."), + ] = None + + +class ModifyMessageRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class DeleteMessageResponse(BaseModel): + id: str + deleted: bool + object: Literal["thread.message.deleted"] + + +class ImageFile(BaseModel): + file_id: Annotated[ + str, + Field( + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.' + ), + ] + detail: Annotated[ + Literal["auto", "low", "high"], + Field( + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`." + ), + ] = "auto" + + +class MessageContentImageFileObject(BaseModel): + type: Annotated[Literal["image_file"], Field(description="Always `image_file`.")] + image_file: ImageFile + + +class ImageFile1(BaseModel): + file_id: Annotated[ + Optional[str], + Field( + description='The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content.' + ), + ] = None + detail: Annotated[ + Literal["auto", "low", "high"], + Field( + description="Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`." + ), + ] = "auto" + + +class MessageDeltaContentImageFileObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["image_file"], Field(description="Always `image_file`.")] + image_file: Optional[ImageFile1] = None + + +class ImageUrl1(BaseModel): + url: Annotated[ + AnyUrl, + Field( + description="The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + ), + ] + detail: Annotated[ + Literal["auto", "low", "high"], + Field( + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto`" + ), + ] = "auto" + + +class MessageContentImageUrlObject(BaseModel): + type: Annotated[Literal["image_url"], Field(description="The type of the content part.")] + image_url: ImageUrl1 + + +class ImageUrl2(BaseModel): + url: Annotated[ + Optional[str], + Field( + description="The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + ), + ] = None + detail: Annotated[ + Literal["auto", "low", "high"], + Field( + description="Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`." + ), + ] = "auto" + + +class MessageDeltaContentImageUrlObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["image_url"], Field(description="Always `image_url`.")] + image_url: Optional[ImageUrl2] = None + + +class MessageContentRefusalObject(BaseModel): + type: Annotated[Literal["refusal"], Field(description="Always `refusal`.")] + refusal: str + + +class MessageRequestContentTextObject(BaseModel): + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Annotated[str, Field(description="Text content to be sent to the model")] + + +class FileCitation(BaseModel): + file_id: Annotated[str, Field(description="The ID of the specific File the citation is from.")] + + +class MessageContentTextAnnotationsFileCitationObject(BaseModel): + type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] + text: Annotated[ + str, + Field(description="The text in the message content that needs to be replaced."), + ] + file_citation: FileCitation + start_index: Annotated[int, Field(ge=0)] + end_index: Annotated[int, Field(ge=0)] + + +class FilePath(BaseModel): + file_id: Annotated[str, Field(description="The ID of the file that was generated.")] + + +class MessageContentTextAnnotationsFilePathObject(BaseModel): + type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] + text: Annotated[ + str, + Field(description="The text in the message content that needs to be replaced."), + ] + file_path: FilePath + start_index: Annotated[int, Field(ge=0)] + end_index: Annotated[int, Field(ge=0)] + + +class MessageDeltaContentRefusalObject(BaseModel): + index: Annotated[int, Field(description="The index of the refusal part in the message.")] + type: Annotated[Literal["refusal"], Field(description="Always `refusal`.")] + refusal: Optional[str] = None + + +class FileCitation1(BaseModel): + file_id: Annotated[ + Optional[str], + Field(description="The ID of the specific File the citation is from."), + ] = None + quote: Annotated[Optional[str], Field(description="The specific quote in the file.")] = None + + +class MessageDeltaContentTextAnnotationsFileCitationObject(BaseModel): + index: Annotated[ + int, Field(description="The index of the annotation in the text content part.") + ] + type: Annotated[Literal["file_citation"], Field(description="Always `file_citation`.")] + text: Annotated[ + Optional[str], + Field(description="The text in the message content that needs to be replaced."), + ] = None + file_citation: Optional[FileCitation1] = None + start_index: Annotated[Optional[int], Field(ge=0)] = None + end_index: Annotated[Optional[int], Field(ge=0)] = None + + +class FilePath1(BaseModel): + file_id: Annotated[ + Optional[str], Field(description="The ID of the file that was generated.") + ] = None + + +class MessageDeltaContentTextAnnotationsFilePathObject(BaseModel): + index: Annotated[ + int, Field(description="The index of the annotation in the text content part.") + ] + type: Annotated[Literal["file_path"], Field(description="Always `file_path`.")] + text: Annotated[ + Optional[str], + Field(description="The text in the message content that needs to be replaced."), + ] = None + file_path: Optional[FilePath1] = None + start_index: Annotated[Optional[int], Field(ge=0)] = None + end_index: Annotated[Optional[int], Field(ge=0)] = None + + +class LastError1(BaseModel): + code: Annotated[ + Literal["server_error", "rate_limit_exceeded"], + Field(description="One of `server_error` or `rate_limit_exceeded`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class MessageCreation(BaseModel): + message_id: Annotated[ + str, + Field(description="The ID of the message that was created by this run step."), + ] + + +class RunStepDetailsMessageCreationObject(BaseModel): + type: Annotated[Literal["message_creation"], Field(description="Always `message_creation`.")] + message_creation: MessageCreation + + +class MessageCreation1(BaseModel): + message_id: Annotated[ + Optional[str], + Field(description="The ID of the message that was created by this run step."), + ] = None + + +class RunStepDeltaStepDetailsMessageCreationObject(BaseModel): + type: Annotated[Literal["message_creation"], Field(description="Always `message_creation`.")] + message_creation: Optional[MessageCreation1] = None + + +class RunStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + type: Annotated[Literal["logs"], Field(description="Always `logs`.")] + logs: Annotated[str, Field(description="The text output from the Code Interpreter tool call.")] + + +class RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject(BaseModel): + index: Annotated[int, Field(description="The index of the output in the outputs array.")] + type: Annotated[Literal["logs"], Field(description="Always `logs`.")] + logs: Annotated[ + Optional[str], + Field(description="The text output from the Code Interpreter tool call."), + ] = None + + +class Image1(BaseModel): + file_id: Annotated[ + str, Field(description="The [file](/docs/api-reference/files) ID of the image.") + ] + + +class RunStepDetailsToolCallsCodeOutputImageObject(BaseModel): + type: Annotated[Literal["image"], Field(description="Always `image`.")] + image: Image1 + + +class Image2(BaseModel): + file_id: Annotated[ + Optional[str], + Field(description="The [file](/docs/api-reference/files) ID of the image."), + ] = None + + +class RunStepDeltaStepDetailsToolCallsCodeOutputImageObject(BaseModel): + index: Annotated[int, Field(description="The index of the output in the outputs array.")] + type: Annotated[Literal["image"], Field(description="Always `image`.")] + image: Optional[Image2] = None + + +class RunStepDetailsToolCallsFileSearchObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call object.")] + type: Annotated[ + Literal["file_search"], + Field( + description="The type of tool call. This is always going to be `file_search` for this type of tool call." + ), + ] + file_search: Annotated[ + Dict[str, Any], + Field(description="For now, this is always going to be an empty object."), + ] + + +class RunStepDeltaStepDetailsToolCallsFileSearchObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call object.")] = None + type: Annotated[ + Literal["file_search"], + Field( + description="The type of tool call. This is always going to be `file_search` for this type of tool call." + ), + ] + file_search: Annotated[ + Dict[str, Any], + Field(description="For now, this is always going to be an empty object."), + ] + + +class Function5(BaseModel): + name: Annotated[str, Field(description="The name of the function.")] + arguments: Annotated[str, Field(description="The arguments passed to the function.")] + output: Annotated[ + Optional[str], + Field( + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." + ), + ] = None + + +class RunStepDetailsToolCallsFunctionObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call object.")] + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call. This is always going to be `function` for this type of tool call." + ), + ] + function: Annotated[ + Function5, Field(description="The definition of the function that was called.") + ] + + +class Function6(BaseModel): + name: Annotated[Optional[str], Field(description="The name of the function.")] = None + arguments: Annotated[ + Optional[str], Field(description="The arguments passed to the function.") + ] = None + output: Annotated[ + Optional[str], + Field( + description="The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet." + ), + ] = None + + +class RunStepDeltaStepDetailsToolCallsFunctionObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call object.")] = None + type: Annotated[ + Literal["function"], + Field( + description="The type of tool call. This is always going to be `function` for this type of tool call." + ), + ] + function: Annotated[ + Optional[Function6], + Field(description="The definition of the function that was called."), + ] = None + + +class VectorStoreExpirationAfter(BaseModel): + anchor: Annotated[ + Literal["last_active_at"], + Field( + description="Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." + ), + ] + days: Annotated[ + int, + Field( + description="The number of days after the anchor time that the vector store will expire.", + ge=1, + le=365, + ), + ] + + +class FileCounts(BaseModel): + in_progress: Annotated[ + int, + Field(description="The number of files that are currently being processed."), + ] + completed: Annotated[ + int, + Field(description="The number of files that have been successfully processed."), + ] + failed: Annotated[int, Field(description="The number of files that have failed to process.")] + cancelled: Annotated[int, Field(description="The number of files that were cancelled.")] + total: Annotated[int, Field(description="The total number of files.")] + + +class VectorStoreObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store"], + Field(description="The object type, which is always `vector_store`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the vector store was created."), + ] + name: Annotated[str, Field(description="The name of the vector store.")] + usage_bytes: Annotated[ + int, + Field(description="The total number of bytes used by the files in the vector store."), + ] + file_counts: FileCounts + status: Annotated[ + Literal["expired", "in_progress", "completed"], + Field( + description="The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use." + ), + ] + expires_after: Optional[VectorStoreExpirationAfter] = None + expires_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the vector store will expire."), + ] = None + last_active_at: Annotated[ + Optional[int], + Field( + description="The Unix timestamp (in seconds) for when the vector store was last active." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class UpdateVectorStoreRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + name: Annotated[Optional[str], Field(description="The name of the vector store.")] = None + expires_after: Optional[VectorStoreExpirationAfter] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class ListVectorStoresResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[VectorStoreObject] + first_id: Annotated[str, Field(examples=["vs_abc123"])] + last_id: Annotated[str, Field(examples=["vs_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class DeleteVectorStoreResponse(BaseModel): + id: str + deleted: bool + object: Literal["vector_store.deleted"] + + +class LastError2(BaseModel): + code: Annotated[ + Literal["server_error", "unsupported_file", "invalid_file"], + Field(description="One of `server_error` or `rate_limit_exceeded`."), + ] + message: Annotated[str, Field(description="A human-readable description of the error.")] + + +class OtherChunkingStrategyResponseParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["other"], Field(description="Always `other`.")] + + +class StaticChunkingStrategy(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + max_chunk_size_tokens: Annotated[ + int, + Field( + description="The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`.", + ge=100, + le=4096, + ), + ] + chunk_overlap_tokens: Annotated[ + int, + Field( + description="The number of tokens that overlap between chunks. The default value is `400`.\n\nNote that the overlap must not exceed half of `max_chunk_size_tokens`.\n" + ), + ] + + +class AutoChunkingStrategyRequestParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["auto"], Field(description="Always `auto`.")] + + +class StaticChunkingStrategyRequestParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: StaticChunkingStrategy + + +class ChunkingStrategyRequestParam( + RootModel[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]] +): + root: Annotated[ + Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy." + ), + ] + + +class CreateVectorStoreFileRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_id: Annotated[ + str, + Field( + description="A [File](/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files." + ), + ] + chunking_strategy: Optional[ChunkingStrategyRequestParam] = None + + +class DeleteVectorStoreFileResponse(BaseModel): + id: str + deleted: bool + object: Literal["vector_store.file.deleted"] + + +class FileCounts1(BaseModel): + in_progress: Annotated[ + int, + Field(description="The number of files that are currently being processed."), + ] + completed: Annotated[int, Field(description="The number of files that have been processed.")] + failed: Annotated[int, Field(description="The number of files that have failed to process.")] + cancelled: Annotated[int, Field(description="The number of files that where cancelled.")] + total: Annotated[int, Field(description="The total number of files.")] + + +class VectorStoreFileBatchObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store.files_batch"], + Field(description="The object type, which is always `vector_store.file_batch`."), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store files batch was created." + ), + ] + vector_store_id: Annotated[ + str, + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to." + ), + ] + status: Annotated[ + Literal["in_progress", "completed", "cancelled", "failed"], + Field( + description="The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`." + ), + ] + file_counts: FileCounts1 + + +class CreateVectorStoreFileBatchRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_ids: Annotated[ + List[str], + Field( + description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", + max_length=500, + min_length=1, + ), + ] + chunking_strategy: Optional[ChunkingStrategyRequestParam] = None + + +class ThreadStreamEvent1(BaseModel): + event: Literal["thread.created"] + data: ThreadObject + + +class ThreadStreamEvent(RootModel[ThreadStreamEvent1]): + root: ThreadStreamEvent1 + + +class ErrorEvent(BaseModel): + event: Literal["error"] + data: Error + + +class DoneEvent(BaseModel): + event: Literal["done"] + data: Literal["[DONE]"] + + +class Datum(BaseModel): + code: Annotated[ + Optional[str], Field(description="An error code identifying the error type.") + ] = None + message: Annotated[ + Optional[str], + Field(description="A human-readable message providing more details about the error."), + ] = None + param: Annotated[ + Optional[str], + Field(description="The name of the parameter that caused the error, if applicable."), + ] = None + line: Annotated[ + Optional[int], + Field( + description="The line number of the input file where the error occurred, if applicable." + ), + ] = None + + +class Errors(BaseModel): + object: Annotated[ + Optional[str], Field(description="The object type, which is always `list`.") + ] = None + data: Optional[List[Datum]] = None + + +class RequestCounts(BaseModel): + total: Annotated[int, Field(description="Total number of requests in the batch.")] + completed: Annotated[ + int, + Field(description="Number of requests that have been completed successfully."), + ] + failed: Annotated[int, Field(description="Number of requests that have failed.")] + + +class Batch(BaseModel): + id: str + object: Annotated[ + Literal["batch"], Field(description="The object type, which is always `batch`.") + ] + endpoint: Annotated[str, Field(description="The OpenAI API endpoint used by the batch.")] + errors: Optional[Errors] = None + input_file_id: Annotated[str, Field(description="The ID of the input file for the batch.")] + completion_window: Annotated[ + str, + Field(description="The time frame within which the batch should be processed."), + ] + status: Annotated[ + Literal[ + "validating", + "failed", + "in_progress", + "finalizing", + "completed", + "expired", + "cancelling", + "cancelled", + ], + Field(description="The current status of the batch."), + ] + output_file_id: Annotated[ + Optional[str], + Field( + description="The ID of the file containing the outputs of successfully executed requests." + ), + ] = None + error_file_id: Annotated[ + Optional[str], + Field(description="The ID of the file containing the outputs of requests with errors."), + ] = None + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the batch was created."), + ] + in_progress_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch started processing."), + ] = None + expires_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch will expire."), + ] = None + finalizing_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch started finalizing."), + ] = None + completed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch was completed."), + ] = None + failed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch failed."), + ] = None + expired_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch expired."), + ] = None + cancelling_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch started cancelling."), + ] = None + cancelled_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the batch was cancelled."), + ] = None + request_counts: Annotated[ + Optional[RequestCounts], + Field(description="The request counts for different statuses within the batch."), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class BatchRequestInput(BaseModel): + custom_id: Annotated[ + Optional[str], + Field( + description="A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch." + ), + ] = None + method: Annotated[ + Optional[Literal["POST"]], + Field( + description="The HTTP method to be used for the request. Currently only `POST` is supported." + ), + ] = None + url: Annotated[ + Optional[str], + Field( + description="The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported." + ), + ] = None + + +class Response(BaseModel): + status_code: Annotated[ + Optional[int], Field(description="The HTTP status code of the response") + ] = None + request_id: Annotated[ + Optional[str], + Field( + description="An unique identifier for the OpenAI API request. Please include this request ID when contacting support." + ), + ] = None + body: Annotated[ + Optional[Dict[str, Any]], Field(description="The JSON body of the response") + ] = None + + +class Error2(BaseModel): + code: Annotated[Optional[str], Field(description="A machine-readable error code.")] = None + message: Annotated[Optional[str], Field(description="A human-readable error message.")] = None + + +class BatchRequestOutput(BaseModel): + id: Optional[str] = None + custom_id: Annotated[ + Optional[str], + Field( + description="A developer-provided per-request id that will be used to match outputs to inputs." + ), + ] = None + response: Optional[Response] = None + error: Annotated[ + Optional[Error2], + Field( + description="For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure." + ), + ] = None + + +class ListBatchesResponse(BaseModel): + data: List[Batch] + first_id: Annotated[Optional[str], Field(examples=["batch_abc123"])] = None + last_id: Annotated[Optional[str], Field(examples=["batch_abc456"])] = None + has_more: bool + object: Literal["list"] + + +class AuditLogActorServiceAccount(BaseModel): + id: Annotated[Optional[str], Field(description="The service account id.")] = None + + +class AuditLogActorUser(BaseModel): + id: Annotated[Optional[str], Field(description="The user id.")] = None + email: Annotated[Optional[str], Field(description="The user email.")] = None + + +class AuditLogActorApiKey(BaseModel): + id: Annotated[Optional[str], Field(description="The tracking id of the API key.")] = None + type: Annotated[ + Optional[Literal["user", "service_account"]], + Field(description="The type of API key. Can be either `user` or `service_account`."), + ] = None + user: Optional[AuditLogActorUser] = None + service_account: Optional[AuditLogActorServiceAccount] = None + + +class AuditLogActorSession(BaseModel): + user: Optional[AuditLogActorUser] = None + ip_address: Annotated[ + Optional[str], + Field(description="The IP address from which the action was performed."), + ] = None + + +class AuditLogActor(BaseModel): + type: Annotated[ + Optional[Literal["session", "api_key"]], + Field(description="The type of actor. Is either `session` or `api_key`."), + ] = None + session: Optional[AuditLogActorSession] = None + api_key: Optional[AuditLogActorApiKey] = None + + +class AuditLogEventType( + RootModel[ + Literal[ + "api_key.created", + "api_key.updated", + "api_key.deleted", + "invite.sent", + "invite.accepted", + "invite.deleted", + "login.succeeded", + "login.failed", + "logout.succeeded", + "logout.failed", + "organization.updated", + "project.created", + "project.updated", + "project.archived", + "service_account.created", + "service_account.updated", + "service_account.deleted", + "user.added", + "user.updated", + "user.deleted", + ] + ] +): + root: Annotated[ + Literal[ + "api_key.created", + "api_key.updated", + "api_key.deleted", + "invite.sent", + "invite.accepted", + "invite.deleted", + "login.succeeded", + "login.failed", + "logout.succeeded", + "logout.failed", + "organization.updated", + "project.created", + "project.updated", + "project.archived", + "service_account.created", + "service_account.updated", + "service_account.deleted", + "user.added", + "user.updated", + "user.deleted", + ], + Field(description="The event type."), + ] + + +class Project(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + name: Annotated[Optional[str], Field(description="The project title.")] = None + + +class Data(BaseModel): + scopes: Annotated[ + Optional[List[str]], + Field(description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`'), + ] = None + + +class ApiKeyCreated(BaseModel): + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None + data: Annotated[ + Optional[Data], Field(description="The payload used to create the API key.") + ] = None + + +class ChangesRequested(BaseModel): + scopes: Annotated[ + Optional[List[str]], + Field(description='A list of scopes allowed for the API key, e.g. `["api.model.request"]`'), + ] = None + + +class ApiKeyUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested], + Field(description="The payload used to update the API key."), + ] = None + + +class ApiKeyDeleted(BaseModel): + id: Annotated[Optional[str], Field(description="The tracking ID of the API key.")] = None + + +class Data1(BaseModel): + email: Annotated[Optional[str], Field(description="The email invited to the organization.")] = ( + None + ) + role: Annotated[ + Optional[str], + Field(description="The role the email was invited to be. Is either `owner` or `member`."), + ] = None + + +class InviteSent(BaseModel): + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None + data: Annotated[ + Optional[Data1], Field(description="The payload used to create the invite.") + ] = None + + +class InviteAccepted(BaseModel): + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None + + +class InviteDeleted(BaseModel): + id: Annotated[Optional[str], Field(description="The ID of the invite.")] = None + + +class LoginFailed(BaseModel): + error_code: Annotated[Optional[str], Field(description="The error code of the failure.")] = None + error_message: Annotated[ + Optional[str], Field(description="The error message of the failure.") + ] = None + + +class LogoutFailed(BaseModel): + error_code: Annotated[Optional[str], Field(description="The error code of the failure.")] = None + error_message: Annotated[ + Optional[str], Field(description="The error message of the failure.") + ] = None + + +class Settings(BaseModel): + threads_ui_visibility: Annotated[ + Optional[str], + Field( + description="Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`." + ), + ] = None + usage_dashboard_visibility: Annotated[ + Optional[str], + Field( + description="Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`." + ), + ] = None + + +class ChangesRequested1(BaseModel): + title: Annotated[Optional[str], Field(description="The organization title.")] = None + description: Annotated[Optional[str], Field(description="The organization description.")] = None + name: Annotated[Optional[str], Field(description="The organization name.")] = None + settings: Optional[Settings] = None + + +class OrganizationUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The organization ID.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested1], + Field(description="The payload used to update the organization settings."), + ] = None + + +class Data2(BaseModel): + name: Annotated[Optional[str], Field(description="The project name.")] = None + title: Annotated[ + Optional[str], + Field(description="The title of the project as seen on the dashboard."), + ] = None + + +class ProjectCreated(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + data: Annotated[ + Optional[Data2], Field(description="The payload used to create the project.") + ] = None + + +class ChangesRequested2(BaseModel): + title: Annotated[ + Optional[str], + Field(description="The title of the project as seen on the dashboard."), + ] = None + + +class ProjectUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested2], + Field(description="The payload used to update the project."), + ] = None + + +class ProjectArchived(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + + +class Data3(BaseModel): + role: Annotated[ + Optional[str], + Field(description="The role of the service account. Is either `owner` or `member`."), + ] = None + + +class ServiceAccountCreated(BaseModel): + id: Annotated[Optional[str], Field(description="The service account ID.")] = None + data: Annotated[ + Optional[Data3], + Field(description="The payload used to create the service account."), + ] = None + + +class ChangesRequested3(BaseModel): + role: Annotated[ + Optional[str], + Field(description="The role of the service account. Is either `owner` or `member`."), + ] = None + + +class ServiceAccountUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The service account ID.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested3], + Field(description="The payload used to updated the service account."), + ] = None + + +class ServiceAccountDeleted(BaseModel): + id: Annotated[Optional[str], Field(description="The service account ID.")] = None + + +class Data4(BaseModel): + role: Annotated[ + Optional[str], + Field(description="The role of the user. Is either `owner` or `member`."), + ] = None + + +class UserAdded(BaseModel): + id: Annotated[Optional[str], Field(description="The user ID.")] = None + data: Annotated[ + Optional[Data4], + Field(description="The payload used to add the user to the project."), + ] = None + + +class ChangesRequested4(BaseModel): + role: Annotated[ + Optional[str], + Field(description="The role of the user. Is either `owner` or `member`."), + ] = None + + +class UserUpdated(BaseModel): + id: Annotated[Optional[str], Field(description="The project ID.")] = None + changes_requested: Annotated[ + Optional[ChangesRequested4], + Field(description="The payload used to update the user."), + ] = None + + +class UserDeleted(BaseModel): + id: Annotated[Optional[str], Field(description="The user ID.")] = None + + +class AuditLog(BaseModel): + id: Annotated[str, Field(description="The ID of this log.")] + type: AuditLogEventType + effective_at: Annotated[int, Field(description="The Unix timestamp (in seconds) of the event.")] + project: Annotated[ + Optional[Project], + Field( + description="The project that the action was scoped to. Absent for actions not scoped to projects." + ), + ] = None + actor: AuditLogActor + api_key_created: Annotated[ + Optional[ApiKeyCreated], + Field( + alias="api_key.created", + description="The details for events with this `type`.", + ), + ] = None + api_key_updated: Annotated[ + Optional[ApiKeyUpdated], + Field( + alias="api_key.updated", + description="The details for events with this `type`.", + ), + ] = None + api_key_deleted: Annotated[ + Optional[ApiKeyDeleted], + Field( + alias="api_key.deleted", + description="The details for events with this `type`.", + ), + ] = None + invite_sent: Annotated[ + Optional[InviteSent], + Field(alias="invite.sent", description="The details for events with this `type`."), + ] = None + invite_accepted: Annotated[ + Optional[InviteAccepted], + Field( + alias="invite.accepted", + description="The details for events with this `type`.", + ), + ] = None + invite_deleted: Annotated[ + Optional[InviteDeleted], + Field( + alias="invite.deleted", + description="The details for events with this `type`.", + ), + ] = None + login_failed: Annotated[ + Optional[LoginFailed], + Field(alias="login.failed", description="The details for events with this `type`."), + ] = None + logout_failed: Annotated[ + Optional[LogoutFailed], + Field( + alias="logout.failed", + description="The details for events with this `type`.", + ), + ] = None + organization_updated: Annotated[ + Optional[OrganizationUpdated], + Field( + alias="organization.updated", + description="The details for events with this `type`.", + ), + ] = None + project_created: Annotated[ + Optional[ProjectCreated], + Field( + alias="project.created", + description="The details for events with this `type`.", + ), + ] = None + project_updated: Annotated[ + Optional[ProjectUpdated], + Field( + alias="project.updated", + description="The details for events with this `type`.", + ), + ] = None + project_archived: Annotated[ + Optional[ProjectArchived], + Field( + alias="project.archived", + description="The details for events with this `type`.", + ), + ] = None + service_account_created: Annotated[ + Optional[ServiceAccountCreated], + Field( + alias="service_account.created", + description="The details for events with this `type`.", + ), + ] = None + service_account_updated: Annotated[ + Optional[ServiceAccountUpdated], + Field( + alias="service_account.updated", + description="The details for events with this `type`.", + ), + ] = None + service_account_deleted: Annotated[ + Optional[ServiceAccountDeleted], + Field( + alias="service_account.deleted", + description="The details for events with this `type`.", + ), + ] = None + user_added: Annotated[ + Optional[UserAdded], + Field(alias="user.added", description="The details for events with this `type`."), + ] = None + user_updated: Annotated[ + Optional[UserUpdated], + Field(alias="user.updated", description="The details for events with this `type`."), + ] = None + user_deleted: Annotated[ + Optional[UserDeleted], + Field(alias="user.deleted", description="The details for events with this `type`."), + ] = None + + +class ListAuditLogsResponse(BaseModel): + object: Literal["list"] + data: List[AuditLog] + first_id: Annotated[str, Field(examples=["audit_log-defb456h8dks"])] + last_id: Annotated[str, Field(examples=["audit_log-hnbkd8s93s"])] + has_more: bool + + +class Invite(BaseModel): + object: Annotated[ + Literal["organization.invite"], + Field(description="The object type, which is always `organization.invite`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + email: Annotated[ + str, + Field(description="The email address of the individual to whom the invite was sent"), + ] + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + status: Annotated[ + Literal["accepted", "expired", "pending"], + Field(description="`accepted`,`expired`, or `pending`"), + ] + invited_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the invite was sent."), + ] + expires_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the invite expires."), + ] + accepted_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) of when the invite was accepted."), + ] = None + + +class InviteListResponse(BaseModel): + object: Annotated[Literal["list"], Field(description="The object type, which is always `list`")] + data: List[Invite] + first_id: Annotated[ + Optional[str], + Field(description="The first `invite_id` in the retrieved `list`"), + ] = None + last_id: Annotated[ + Optional[str], Field(description="The last `invite_id` in the retrieved `list`") + ] = None + has_more: Annotated[ + Optional[bool], + Field( + description="The `has_more` property is used for pagination to indicate there are additional results." + ), + ] = None + + +class InviteRequest(BaseModel): + email: Annotated[str, Field(description="Send an email to this address")] + role: Annotated[Literal["reader", "owner"], Field(description="`owner` or `reader`")] + + +class InviteDeleteResponse(BaseModel): + object: Annotated[ + Literal["organization.invite.deleted"], + Field(description="The object type, which is always `organization.invite.deleted`"), + ] + id: str + deleted: bool + + +class User(BaseModel): + object: Annotated[ + Literal["organization.user"], + Field(description="The object type, which is always `organization.user`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the user")] + email: Annotated[str, Field(description="The email address of the user")] + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + added_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the user was added."), + ] + + +class UserListResponse(BaseModel): + object: Literal["list"] + data: List[User] + first_id: str + last_id: str + has_more: bool + + +class UserRoleUpdateRequest(BaseModel): + role: Annotated[Literal["owner", "reader"], Field(description="`owner` or `reader`")] + + +class UserDeleteResponse(BaseModel): + object: Literal["organization.user.deleted"] + id: str + deleted: bool + + +class Project1(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + object: Annotated[ + Literal["organization.project"], + Field(description="The object type, which is always `organization.project`"), + ] + name: Annotated[str, Field(description="The name of the project. This appears in reporting.")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the project was created."), + ] + archived_at: Annotated[ + Optional[int], + Field( + description="The Unix timestamp (in seconds) of when the project was archived or `null`." + ), + ] = None + status: Annotated[Literal["active", "archived"], Field(description="`active` or `archived`")] + + +class ProjectListResponse(BaseModel): + object: Literal["list"] + data: List[Project1] + first_id: str + last_id: str + has_more: bool + + +class ProjectCreateRequest(BaseModel): + name: Annotated[ + str, + Field(description="The friendly name of the project, this name appears in reports."), + ] + + +class ProjectUpdateRequest(BaseModel): + name: Annotated[ + str, + Field(description="The updated name of the project, this name appears in reports."), + ] + + +class DefaultProjectErrorResponse(BaseModel): + code: int + message: str + + +class ProjectUser(BaseModel): + object: Annotated[ + Literal["organization.project.user"], + Field(description="The object type, which is always `organization.project.user`"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the user")] + email: Annotated[str, Field(description="The email address of the user")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + added_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the project was added."), + ] + + +class ProjectUserListResponse(BaseModel): + object: str + data: List[ProjectUser] + first_id: str + last_id: str + has_more: bool + + +class ProjectUserCreateRequest(BaseModel): + user_id: Annotated[str, Field(description="The ID of the user.")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + + +class ProjectUserUpdateRequest(BaseModel): + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + + +class ProjectUserDeleteResponse(BaseModel): + object: Literal["organization.project.user.deleted"] + id: str + deleted: bool + + +class ProjectServiceAccount(BaseModel): + object: Annotated[ + Literal["organization.project.service_account"], + Field( + description="The object type, which is always `organization.project.service_account`" + ), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + name: Annotated[str, Field(description="The name of the service account")] + role: Annotated[Literal["owner", "member"], Field(description="`owner` or `member`")] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the service account was created" + ), + ] + + +class ProjectServiceAccountListResponse(BaseModel): + object: Literal["list"] + data: List[ProjectServiceAccount] + first_id: str + last_id: str + has_more: bool + + +class ProjectServiceAccountCreateRequest(BaseModel): + name: Annotated[str, Field(description="The name of the service account being created.")] + + +class ProjectServiceAccountApiKey(BaseModel): + object: Annotated[ + Literal["organization.project.service_account.api_key"], + Field( + description="The object type, which is always `organization.project.service_account.api_key`" + ), + ] + value: str + name: str + created_at: int + id: str + + +class ProjectServiceAccountDeleteResponse(BaseModel): + object: Literal["organization.project.service_account.deleted"] + id: str + deleted: bool + + +class Owner(BaseModel): + type: Annotated[ + Optional[Literal["user", "service_account"]], + Field(description="`user` or `service_account`"), + ] = None + user: Optional[ProjectUser] = None + service_account: Optional[ProjectServiceAccount] = None + + +class ProjectApiKey(BaseModel): + object: Annotated[ + Literal["organization.project.api_key"], + Field(description="The object type, which is always `organization.project.api_key`"), + ] + redacted_value: Annotated[str, Field(description="The redacted value of the API key")] + name: Annotated[str, Field(description="The name of the API key")] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the API key was created"), + ] + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints"), + ] + owner: Owner + + +class ProjectApiKeyListResponse(BaseModel): + object: Literal["list"] + data: List[ProjectApiKey] + first_id: str + last_id: str + has_more: bool + + +class ProjectApiKeyDeleteResponse(BaseModel): + object: Literal["organization.project.api_key.deleted"] + id: str + deleted: bool + + +class ListModelsResponse(BaseModel): + object: Literal["list"] + data: List[Model] + + +class CreateCompletionRequest(BaseModel): + model: Annotated[ + Union[str, Literal["gpt-3.5-turbo-instruct", "davinci-002", "babbage-002"]], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] + prompt: Annotated[ + Optional[Union[Optional[str], List[str], Prompt, Prompt1]], + Field( + description="The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays.\n\nNote that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document.\n" + ), + ] + best_of: Annotated[ + Optional[int], + Field( + description='Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed.\n\nWhen used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n', + ge=0, + le=20, + ), + ] = 1 + echo: Annotated[ + Optional[bool], + Field(description="Echo back the prompt in addition to the completion\n"), + ] = False + frequency_penalty: Annotated[ + Optional[float], + Field( + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] = 0 + logit_bias: Annotated[ + Optional[Dict[str, int]], + Field( + description='Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n\nAs an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated.\n' + ), + ] = None + logprobs: Annotated[ + Optional[int], + Field( + description="Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response.\n\nThe maximum value for `logprobs` is 5.\n", + ge=0, + le=5, + ), + ] = None + max_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of [tokens](/tokenizer) that can be generated in the completion.\n\nThe token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n", + examples=[16], + ge=0, + ), + ] = 16 + n: Annotated[ + Optional[int], + Field( + description="How many completions to generate for each prompt.\n\n**Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`.\n", + examples=[1], + ge=1, + le=128, + ), + ] = 1 + presence_penalty: Annotated[ + Optional[float], + Field( + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] = 0 + seed: Annotated[ + Optional[int], + Field( + description="If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\n\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ge=-9223372036854775808, + le=9223372036854775807, + ), + ] = None + stop: Annotated[ + Optional[Union[Optional[str], Stop]], + Field( + description="Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence.\n" + ), + ] = None + stream: Annotated[ + Optional[bool], + Field( + description="Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n" + ), + ] = False + stream_options: Optional[ChatCompletionStreamOptions] = None + suffix: Annotated[ + Optional[str], + Field( + description="The suffix that comes after a completion of inserted text.\n\nThis parameter is only supported for `gpt-3.5-turbo-instruct`.\n", + examples=["test."], + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] = 1 + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] = None + + +class CreateCompletionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the completion.")] + choices: Annotated[ + List[Choice], + Field( + description="The list of completion choices the model generated for the input prompt." + ), + ] + created: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) of when the completion was created."), + ] + model: Annotated[str, Field(description="The model used for completion.")] + system_fingerprint: Annotated[ + Optional[str], + Field( + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" + ), + ] = None + object: Annotated[ + Literal["text_completion"], + Field(description='The object type, which is always "text_completion"'), + ] + usage: Optional[CompletionUsage] = None + + +class ChatCompletionTool(BaseModel): + type: Annotated[ + Literal["function"], + Field(description="The type of the tool. Currently, only `function` is supported."), + ] + function: FunctionObject + + +class ChatCompletionToolChoiceOption( + RootModel[Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice]] +): + root: Annotated[ + Union[Literal["none", "auto", "required"], ChatCompletionNamedToolChoice], + Field( + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tool and instead generates a message.\n`auto` means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools.\nSpecifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n\n`none` is the default when no tools are present. `auto` is the default if tools are present.\n' + ), + ] + + +class ChatCompletionMessageToolCalls(RootModel[List[ChatCompletionMessageToolCall]]): + root: Annotated[ + List[ChatCompletionMessageToolCall], + Field(description="The tool calls generated by the model, such as function calls."), + ] + + +class ChatCompletionResponseMessage(BaseModel): + content: Annotated[Optional[str], Field(description="The contents of the message.")] = None + refusal: Annotated[ + Optional[str], Field(description="The refusal message generated by the model.") + ] = None + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + role: Annotated[ + Literal["assistant"], + Field(description="The role of the author of this message."), + ] + function_call: Annotated[ + Optional[FunctionCall], + Field( + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + ), + ] = None + + +class Choice1(BaseModel): + finish_reason: Annotated[ + Literal["stop", "length", "tool_calls", "content_filter", "function_call"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence,\n`length` if the maximum number of tokens specified in the request was reached,\n`content_filter` if content was omitted due to a flag from our content filters,\n`tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + message: ChatCompletionResponseMessage + logprobs: Annotated[ + Optional[Logprobs2], + Field(description="Log probability information for the choice."), + ] = None + + +class CreateChatCompletionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the chat completion.")] + choices: Annotated[ + List[Choice1], + Field( + description="A list of chat completion choices. Can be more than one if `n` is greater than 1." + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created." + ), + ] + model: Annotated[str, Field(description="The model used for the chat completion.")] + service_tier: Annotated[ + Optional[Literal["scale", "default"]], + Field( + description="The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request.", + examples=["scale"], + ), + ] = None + system_fingerprint: Annotated[ + Optional[str], + Field( + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" + ), + ] = None + object: Annotated[ + Literal["chat.completion"], + Field(description="The object type, which is always `chat.completion`."), + ] + usage: Optional[CompletionUsage] = None + + +class Choice2(BaseModel): + finish_reason: Annotated[ + Literal["stop", "length", "function_call", "content_filter"], + Field( + description="The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function.\n" + ), + ] + index: Annotated[int, Field(description="The index of the choice in the list of choices.")] + message: ChatCompletionResponseMessage + + +class CreateChatCompletionFunctionResponse(BaseModel): + id: Annotated[str, Field(description="A unique identifier for the chat completion.")] + choices: Annotated[ + List[Choice2], + Field( + description="A list of chat completion choices. Can be more than one if `n` is greater than 1." + ), + ] + created: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) of when the chat completion was created." + ), + ] + model: Annotated[str, Field(description="The model used for the chat completion.")] + system_fingerprint: Annotated[ + Optional[str], + Field( + description="This fingerprint represents the backend configuration that the model runs with.\n\nCan be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism.\n" + ), + ] = None + object: Annotated[ + Literal["chat.completion"], + Field(description="The object type, which is always `chat.completion`."), + ] + usage: Optional[CompletionUsage] = None + + +class ImagesResponse(BaseModel): + created: int + data: List[Image] + + +class ListFilesResponse(BaseModel): + data: List[OpenAIFile] + object: Literal["list"] + + +class ListFineTuningJobEventsResponse(BaseModel): + data: List[FineTuningJobEvent] + object: Literal["list"] + + +class ListFineTuningJobCheckpointsResponse(BaseModel): + data: List[FineTuningJobCheckpoint] + object: Literal["list"] + first_id: Optional[str] = None + last_id: Optional[str] = None + has_more: bool + + +class CreateEmbeddingResponse(BaseModel): + data: Annotated[ + List[Embedding], + Field(description="The list of embeddings generated by the model."), + ] + model: Annotated[ + str, Field(description="The name of the model used to generate the embedding.") + ] + object: Annotated[ + Literal["list"], Field(description='The object type, which is always "list".') + ] + usage: Annotated[Usage1, Field(description="The usage information for the request.")] + + +class FineTuningJob(BaseModel): + id: Annotated[ + str, + Field(description="The object identifier, which can be referenced in the API endpoints."), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job was created." + ), + ] + error: Annotated[ + Optional[Error1], + Field( + description="For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure." + ), + ] + fine_tuned_model: Annotated[ + Optional[str], + Field( + description="The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running." + ), + ] = None + finished_at: Annotated[ + Optional[int], + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running." + ), + ] = None + hyperparameters: Annotated[ + Hyperparameters1, + Field( + description="The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details." + ), + ] + model: Annotated[str, Field(description="The base model that is being fine-tuned.")] + object: Annotated[ + Literal["fine_tuning.job"], + Field(description='The object type, which is always "fine_tuning.job".'), + ] + organization_id: Annotated[ + str, Field(description="The organization that owns the fine-tuning job.") + ] + result_files: Annotated[ + List[str], + Field( + description="The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + status: Annotated[ + Literal["validating_files", "queued", "running", "succeeded", "failed", "cancelled"], + Field( + description="The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`." + ), + ] + trained_tokens: Annotated[ + Optional[int], + Field( + description="The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running." + ), + ] = None + training_file: Annotated[ + str, + Field( + description="The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] + validation_file: Annotated[ + Optional[str], + Field( + description="The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents)." + ), + ] = None + integrations: Annotated[ + Optional[List[FineTuningIntegration]], + Field( + description="A list of integrations to enable for this fine-tuning job.", + max_length=5, + ), + ] = None + seed: Annotated[int, Field(description="The seed used for the fine-tuning job.")] + estimated_finish: Annotated[ + Optional[int], + Field( + description="The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running." + ), + ] = None + + +class AssistantObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["assistant"], + Field(description="The object type, which is always `assistant`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the assistant was created."), + ] + name: Annotated[ + Optional[str], + Field( + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] = None + description: Annotated[ + Optional[str], + Field( + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] = None + model: Annotated[ + str, + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] + instructions: Annotated[ + Optional[str], + Field( + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] = None + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] + tool_resources: Annotated[ + Optional[ToolResources], + Field( + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] = 1 + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class CreateAssistantRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n", + examples=["gpt-4o"], + ), + ] + name: Annotated[ + Optional[str], + Field( + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] = None + description: Annotated[ + Optional[str], + Field( + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] = None + instructions: Annotated[ + Optional[str], + Field( + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] = None + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] = [] + tool_resources: Annotated[ + Optional[ToolResources1], + Field( + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] = 1 + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class ModifyAssistantRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + model: Annotated[ + Optional[str], + Field( + description="ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them.\n" + ), + ] = None + name: Annotated[ + Optional[str], + Field( + description="The name of the assistant. The maximum length is 256 characters.\n", + max_length=256, + ), + ] = None + description: Annotated[ + Optional[str], + Field( + description="The description of the assistant. The maximum length is 512 characters.\n", + max_length=512, + ), + ] = None + instructions: Annotated[ + Optional[str], + Field( + description="The system instructions that the assistant uses. The maximum length is 256,000 characters.\n", + max_length=256000, + ), + ] = None + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`.\n", + max_length=128, + ), + ] = [] + tool_resources: Annotated[ + Optional[ToolResources2], + Field( + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] = 1 + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class ListAssistantsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[AssistantObject] + first_id: Annotated[str, Field(examples=["asst_abc123"])] + last_id: Annotated[str, Field(examples=["asst_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class AssistantsApiToolChoiceOption( + RootModel[Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice]] +): + root: Annotated[ + Union[Literal["none", "auto", "required"], AssistantsNamedToolChoice], + Field( + description='Controls which (if any) tool is called by the model.\n`none` means the model will not call any tools and instead generates a message.\n`auto` is the default value and means the model can pick between generating a message or calling one or more tools.\n`required` means the model must call one or more tools before responding to the user.\nSpecifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool.\n' + ), + ] + + +class SubmitToolOutputs(BaseModel): + tool_calls: Annotated[ + List[RunToolCallObject], Field(description="A list of the relevant tool calls.") + ] + + +class RequiredAction(BaseModel): + type: Annotated[ + Literal["submit_tool_outputs"], + Field(description="For now, this is always `submit_tool_outputs`."), + ] + submit_tool_outputs: Annotated[ + SubmitToolOutputs, + Field(description="Details on the tool outputs needed for this run to continue."), + ] + + +class RunObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread.run"], + Field(description="The object type, which is always `thread.run`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run was created."), + ] + thread_id: Annotated[ + str, + Field( + description="The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run." + ), + ] + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run." + ), + ] + status: Annotated[ + Literal[ + "queued", + "in_progress", + "requires_action", + "cancelling", + "cancelled", + "failed", + "completed", + "incomplete", + "expired", + ], + Field( + description="The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`." + ), + ] + required_action: Annotated[ + Optional[RequiredAction], + Field( + description="Details on the action required to continue the run. Will be `null` if no action is required." + ), + ] + last_error: Annotated[ + Optional[LastError], + Field( + description="The last error associated with this run. Will be `null` if there are no errors." + ), + ] + expires_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the run will expire."), + ] = None + started_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the run was started."), + ] = None + cancelled_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the run was cancelled."), + ] = None + failed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the run failed."), + ] = None + completed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the run was completed."), + ] = None + incomplete_details: Annotated[ + Optional[IncompleteDetails], + Field( + description="Details on why the run is incomplete. Will be `null` if the run is not incomplete." + ), + ] + model: Annotated[ + str, + Field( + description="The model that the [assistant](/docs/api-reference/assistants) used for this run." + ), + ] + instructions: Annotated[ + str, + Field( + description="The instructions that the [assistant](/docs/api-reference/assistants) used for this run." + ), + ] + tools: Annotated[ + List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]], + Field( + description="The list of tools that the [assistant](/docs/api-reference/assistants) used for this run.", + max_length=20, + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + usage: RunCompletionUsage + temperature: Annotated[ + Optional[float], + Field(description="The sampling temperature used for this run. If not set, defaults to 1."), + ] = None + top_p: Annotated[ + Optional[float], + Field( + description="The nucleus sampling value used for this run. If not set, defaults to 1." + ), + ] = None + max_prompt_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of prompt tokens specified to have been used over the course of the run.\n", + ge=256, + ), + ] = None + max_completion_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of completion tokens specified to have been used over the course of the run.\n", + ge=256, + ), + ] = None + truncation_strategy: Annotated[Optional[TruncationObject], Field(...)] + tool_choice: Annotated[Optional[AssistantsApiToolChoiceOption], Field(...)] + parallel_tool_calls: ParallelToolCalls + response_format: Annotated[Optional[AssistantsApiResponseFormatOption], Field(...)] + + +class ListRunsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[RunObject] + first_id: Annotated[str, Field(examples=["run_abc123"])] + last_id: Annotated[str, Field(examples=["run_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class Content4( + RootModel[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageRequestContentTextObject, + ] + ] + ] +): + root: Annotated[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageRequestContentTextObject, + ] + ], + Field( + description="An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview).", + min_length=1, + title="Array of content parts", + ), + ] + + +class CreateMessageRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + role: Annotated[ + Literal["user", "assistant"], + Field( + description="The role of the entity that is creating the message. Allowed values include:\n- `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages.\n- `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation.\n" + ), + ] + content: Union[str, Content4] + attachments: Annotated[ + Optional[List[Attachment]], + Field( + description="A list of files attached to the message, and the tools they should be added to." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class Text(BaseModel): + value: Annotated[str, Field(description="The data that makes up the text.")] + annotations: List[ + Union[ + MessageContentTextAnnotationsFileCitationObject, + MessageContentTextAnnotationsFilePathObject, + ] + ] + + +class MessageContentTextObject(BaseModel): + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Text + + +class Text1(BaseModel): + value: Annotated[Optional[str], Field(description="The data that makes up the text.")] = None + annotations: Optional[ + List[ + Union[ + MessageDeltaContentTextAnnotationsFileCitationObject, + MessageDeltaContentTextAnnotationsFilePathObject, + ] + ] + ] = None + + +class MessageDeltaContentTextObject(BaseModel): + index: Annotated[int, Field(description="The index of the content part in the message.")] + type: Annotated[Literal["text"], Field(description="Always `text`.")] + text: Optional[Text1] = None + + +class CodeInterpreter7(BaseModel): + input: Annotated[str, Field(description="The input to the Code Interpreter tool call.")] + outputs: Annotated[ + List[ + Union[ + RunStepDetailsToolCallsCodeOutputLogsObject, + RunStepDetailsToolCallsCodeOutputImageObject, + ] + ], + Field( + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type." + ), + ] + + +class RunStepDetailsToolCallsCodeObject(BaseModel): + id: Annotated[str, Field(description="The ID of the tool call.")] + type: Annotated[ + Literal["code_interpreter"], + Field( + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." + ), + ] + code_interpreter: Annotated[ + CodeInterpreter7, + Field(description="The Code Interpreter tool call definition."), + ] + + +class CodeInterpreter8(BaseModel): + input: Annotated[ + Optional[str], Field(description="The input to the Code Interpreter tool call.") + ] = None + outputs: Annotated[ + Optional[ + List[ + Union[ + RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject, + RunStepDeltaStepDetailsToolCallsCodeOutputImageObject, + ] + ] + ], + Field( + description="The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type." + ), + ] = None + + +class RunStepDeltaStepDetailsToolCallsCodeObject(BaseModel): + index: Annotated[int, Field(description="The index of the tool call in the tool calls array.")] + id: Annotated[Optional[str], Field(description="The ID of the tool call.")] = None + type: Annotated[ + Literal["code_interpreter"], + Field( + description="The type of tool call. This is always going to be `code_interpreter` for this type of tool call." + ), + ] + code_interpreter: Annotated[ + Optional[CodeInterpreter8], + Field(description="The Code Interpreter tool call definition."), + ] = None + + +class CreateVectorStoreRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + file_ids: Annotated[ + Optional[List[str]], + Field( + description="A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files.", + max_length=500, + ), + ] = None + name: Annotated[Optional[str], Field(description="The name of the vector store.")] = None + expires_after: Optional[VectorStoreExpirationAfter] = None + chunking_strategy: Annotated[ + Optional[Union[AutoChunkingStrategyRequestParam, StaticChunkingStrategyRequestParam]], + Field( + description="The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty." + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class StaticChunkingStrategyResponseParam(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + type: Annotated[Literal["static"], Field(description="Always `static`.")] + static: StaticChunkingStrategy + + +class RunStreamEvent1(BaseModel): + event: Literal["thread.run.created"] + data: RunObject + + +class RunStreamEvent2(BaseModel): + event: Literal["thread.run.queued"] + data: RunObject + + +class RunStreamEvent3(BaseModel): + event: Literal["thread.run.in_progress"] + data: RunObject + + +class RunStreamEvent4(BaseModel): + event: Literal["thread.run.requires_action"] + data: RunObject + + +class RunStreamEvent5(BaseModel): + event: Literal["thread.run.completed"] + data: RunObject + + +class RunStreamEvent6(BaseModel): + event: Literal["thread.run.incomplete"] + data: RunObject + + +class RunStreamEvent7(BaseModel): + event: Literal["thread.run.failed"] + data: RunObject + + +class RunStreamEvent8(BaseModel): + event: Literal["thread.run.cancelling"] + data: RunObject + + +class RunStreamEvent9(BaseModel): + event: Literal["thread.run.cancelled"] + data: RunObject + + +class RunStreamEvent10(BaseModel): + event: Literal["thread.run.expired"] + data: RunObject + + +class RunStreamEvent( + RootModel[ + Union[ + RunStreamEvent1, + RunStreamEvent2, + RunStreamEvent3, + RunStreamEvent4, + RunStreamEvent5, + RunStreamEvent6, + RunStreamEvent7, + RunStreamEvent8, + RunStreamEvent9, + RunStreamEvent10, + ] + ] +): + root: Union[ + RunStreamEvent1, + RunStreamEvent2, + RunStreamEvent3, + RunStreamEvent4, + RunStreamEvent5, + RunStreamEvent6, + RunStreamEvent7, + RunStreamEvent8, + RunStreamEvent9, + RunStreamEvent10, + ] + + +class ProjectServiceAccountCreateResponse(BaseModel): + object: Literal["organization.project.service_account"] + id: str + name: str + role: Annotated[ + Literal["member"], + Field(description="Service accounts can only have one role of type `member`"), + ] + created_at: int + api_key: ProjectServiceAccountApiKey + + +class ChatCompletionRequestAssistantMessage(BaseModel): + content: Annotated[ + Optional[Union[Optional[str], Content2]], + Field( + description="The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified.\n" + ), + ] = None + refusal: Annotated[ + Optional[str], Field(description="The refusal message by the assistant.") + ] = None + role: Annotated[ + Literal["assistant"], + Field(description="The role of the messages author, in this case `assistant`."), + ] + name: Annotated[ + Optional[str], + Field( + description="An optional name for the participant. Provides the model information to differentiate between participants of the same role." + ), + ] = None + tool_calls: Optional[ChatCompletionMessageToolCalls] = None + function_call: Annotated[ + Optional[FunctionCall], + Field( + description="Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + ), + ] = None + + +class FineTuneChatCompletionRequestAssistantMessage(ChatCompletionRequestAssistantMessage): + weight: Annotated[ + Optional[Literal[0, 1]], + Field(description="Controls whether the assistant message is trained against (0 or 1)"), + ] = None + role: Annotated[ + Literal["assistant"], + Field(description="The role of the messages author, in this case `assistant`."), + ] + + +class ListPaginatedFineTuningJobsResponse(BaseModel): + data: List[FineTuningJob] + has_more: bool + object: Literal["list"] + + +class FinetuneChatRequestInput(BaseModel): + messages: Annotated[ + Optional[ + List[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + FineTuneChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + ] + ], + Field(min_length=1), + ] = None + tools: Annotated[ + Optional[List[ChatCompletionTool]], + Field(description="A list of tools the model may generate JSON inputs for."), + ] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + functions: Annotated[ + Optional[List[ChatCompletionFunctions]], + Field( + description="A list of functions the model may generate JSON inputs for.", + max_length=128, + min_length=1, + ), + ] = None + + +class CreateRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run." + ), + ] + model: Annotated[ + Optional[ + Union[ + Optional[str], + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ] + ], + Field( + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + examples=["gpt-4o"], + ), + ] = None + instructions: Annotated[ + Optional[str], + Field( + description="Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis." + ), + ] = None + additional_instructions: Annotated[ + Optional[str], + Field( + description="Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions." + ), + ] = None + additional_messages: Annotated[ + Optional[List[CreateMessageRequest]], + Field(description="Adds additional messages to the thread before creating the run."), + ] = None + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_length=20, + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] = 1 + stream: Annotated[ + Optional[bool], + Field( + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" + ), + ] = None + max_prompt_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] = None + max_completion_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] = None + truncation_strategy: Optional[TruncationObject] = None + tool_choice: Optional[AssistantsApiToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class CreateThreadRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + messages: Annotated[ + Optional[List[CreateMessageRequest]], + Field( + description="A list of [messages](/docs/api-reference/messages) to start the thread with." + ), + ] = None + tool_resources: Annotated[ + Optional[ToolResources5], + Field( + description="A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + + +class MessageObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["thread.message"], + Field(description="The object type, which is always `thread.message`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the message was created."), + ] + thread_id: Annotated[ + str, + Field( + description="The [thread](/docs/api-reference/threads) ID that this message belongs to." + ), + ] + status: Annotated[ + Literal["in_progress", "incomplete", "completed"], + Field( + description="The status of the message, which can be either `in_progress`, `incomplete`, or `completed`." + ), + ] + incomplete_details: Annotated[ + Optional[IncompleteDetails1], + Field(description="On an incomplete message, details about why the message is incomplete."), + ] + completed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the message was completed."), + ] = None + incomplete_at: Annotated[ + Optional[int], + Field( + description="The Unix timestamp (in seconds) for when the message was marked as incomplete." + ), + ] = None + role: Annotated[ + Literal["user", "assistant"], + Field(description="The entity that produced the message. One of `user` or `assistant`."), + ] + content: Annotated[ + List[ + Union[ + MessageContentImageFileObject, + MessageContentImageUrlObject, + MessageContentTextObject, + MessageContentRefusalObject, + ] + ], + Field(description="The content of the message in array of text and/or images."), + ] + assistant_id: Annotated[ + Optional[str], + Field( + description="If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message." + ), + ] = None + run_id: Annotated[ + Optional[str], + Field( + description="The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints." + ), + ] = None + attachments: Annotated[ + Optional[List[Attachment]], + Field( + description="A list of files attached to the message, and the tools they were added to." + ), + ] + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + + +class Delta(BaseModel): + role: Annotated[ + Optional[Literal["user", "assistant"]], + Field(description="The entity that produced the message. One of `user` or `assistant`."), + ] = None + content: Annotated[ + Optional[ + List[ + Union[ + MessageDeltaContentImageFileObject, + MessageDeltaContentTextObject, + MessageDeltaContentRefusalObject, + MessageDeltaContentImageUrlObject, + ] + ] + ], + Field(description="The content of the message in array of text and/or images."), + ] = None + + +class MessageDeltaObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the message, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.message.delta"], + Field(description="The object type, which is always `thread.message.delta`."), + ] + delta: Annotated[ + Delta, + Field(description="The delta containing the fields that have changed on the Message."), + ] + + +class ListMessagesResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[MessageObject] + first_id: Annotated[str, Field(examples=["msg_abc123"])] + last_id: Annotated[str, Field(examples=["msg_abc123"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class RunStepDetailsToolCallsObject(BaseModel): + type: Annotated[Literal["tool_calls"], Field(description="Always `tool_calls`.")] + tool_calls: Annotated[ + List[ + Union[ + RunStepDetailsToolCallsCodeObject, + RunStepDetailsToolCallsFileSearchObject, + RunStepDetailsToolCallsFunctionObject, + ] + ], + Field( + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n" + ), + ] + + +class RunStepDeltaStepDetailsToolCallsObject(BaseModel): + type: Annotated[Literal["tool_calls"], Field(description="Always `tool_calls`.")] + tool_calls: Annotated[ + Optional[ + List[ + Union[ + RunStepDeltaStepDetailsToolCallsCodeObject, + RunStepDeltaStepDetailsToolCallsFileSearchObject, + RunStepDeltaStepDetailsToolCallsFunctionObject, + ] + ] + ], + Field( + description="An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`.\n" + ), + ] = None + + +class VectorStoreFileObject(BaseModel): + id: Annotated[ + str, + Field(description="The identifier, which can be referenced in API endpoints."), + ] + object: Annotated[ + Literal["vector_store.file"], + Field(description="The object type, which is always `vector_store.file`."), + ] + usage_bytes: Annotated[ + int, + Field( + description="The total vector store usage in bytes. Note that this may be different from the original file size." + ), + ] + created_at: Annotated[ + int, + Field( + description="The Unix timestamp (in seconds) for when the vector store file was created." + ), + ] + vector_store_id: Annotated[ + str, + Field( + description="The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to." + ), + ] + status: Annotated[ + Literal["in_progress", "completed", "cancelled", "failed"], + Field( + description="The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use." + ), + ] + last_error: Annotated[ + Optional[LastError2], + Field( + description="The last error associated with this vector store file. Will be `null` if there are no errors." + ), + ] + chunking_strategy: Annotated[ + Optional[Union[StaticChunkingStrategyResponseParam, OtherChunkingStrategyResponseParam]], + Field(description="The strategy used to chunk the file."), + ] = None + + +class ListVectorStoreFilesResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[VectorStoreFileObject] + first_id: Annotated[str, Field(examples=["file-abc123"])] + last_id: Annotated[str, Field(examples=["file-abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class MessageStreamEvent1(BaseModel): + event: Literal["thread.message.created"] + data: MessageObject + + +class MessageStreamEvent2(BaseModel): + event: Literal["thread.message.in_progress"] + data: MessageObject + + +class MessageStreamEvent3(BaseModel): + event: Literal["thread.message.delta"] + data: MessageDeltaObject + + +class MessageStreamEvent4(BaseModel): + event: Literal["thread.message.completed"] + data: MessageObject + + +class MessageStreamEvent5(BaseModel): + event: Literal["thread.message.incomplete"] + data: MessageObject + + +class MessageStreamEvent( + RootModel[ + Union[ + MessageStreamEvent1, + MessageStreamEvent2, + MessageStreamEvent3, + MessageStreamEvent4, + MessageStreamEvent5, + ] + ] +): + root: Union[ + MessageStreamEvent1, + MessageStreamEvent2, + MessageStreamEvent3, + MessageStreamEvent4, + MessageStreamEvent5, + ] + + +class ChatCompletionRequestMessage( + RootModel[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ] + ] +): + root: Annotated[ + Union[ + ChatCompletionRequestSystemMessage, + ChatCompletionRequestUserMessage, + ChatCompletionRequestAssistantMessage, + ChatCompletionRequestToolMessage, + ChatCompletionRequestFunctionMessage, + ], + Field(discriminator="role"), + ] + + +class CreateChatCompletionRequest(BaseModel): + messages: Annotated[ + List[ChatCompletionRequestMessage], + Field( + description="A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models).", + min_length=1, + ), + ] + model: Annotated[ + Union[ + str, + Literal[ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ], + Field( + description="ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API.", + examples=["gpt-4o"], + ), + ] + frequency_penalty: Annotated[ + Optional[float], + Field( + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] = 0 + logit_bias: Annotated[ + Optional[Dict[str, int]], + Field( + description="Modify the likelihood of specified tokens appearing in the completion.\n\nAccepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token.\n" + ), + ] = None + logprobs: Annotated[ + Optional[bool], + Field( + description="Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`." + ), + ] = False + top_logprobs: Annotated[ + Optional[int], + Field( + description="An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used.", + ge=0, + le=20, + ), + ] = None + max_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of [tokens](/tokenizer) that can be generated in the chat completion.\n\nThe total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens.\n" + ), + ] = None + n: Annotated[ + Optional[int], + Field( + description="How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs.", + examples=[1], + ge=1, + le=128, + ), + ] = 1 + presence_penalty: Annotated[ + Optional[float], + Field( + description="Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.\n\n[See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details)\n", + ge=-2.0, + le=2.0, + ), + ] = 0 + response_format: Annotated[ + Optional[Union[ResponseFormatText, ResponseFormatJsonObject, ResponseFormatJsonSchema]], + Field( + description='An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`.\n\nSetting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs).\n\nSetting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON.\n\n**Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length.\n' + ), + ] = None + seed: Annotated[ + Optional[int], + Field( + description="This feature is in Beta.\nIf specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result.\nDeterminism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend.\n", + ge=-9223372036854775808, + le=9223372036854775807, + ), + ] = None + service_tier: Annotated[ + Optional[Literal["auto", "default"]], + Field( + description="Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service:\n - If set to 'auto', the system will utilize scale tier credits until they are exhausted.\n - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee.\n - When not set, the default behavior is 'auto'.\n\n When this parameter is set, the response body will include the `service_tier` utilized.\n" + ), + ] = None + stop: Annotated[ + Union[Optional[str], Stop1], + Field(description="Up to 4 sequences where the API will stop generating further tokens.\n"), + ] = None + stream: Annotated[ + Optional[bool], + Field( + description="If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions).\n" + ), + ] = False + stream_options: Optional[ChatCompletionStreamOptions] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n\nWe generally recommend altering this or `top_p` but not both.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or `temperature` but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] = 1 + tools: Annotated[ + Optional[List[ChatCompletionTool]], + Field( + description="A list of tools the model may call. Currently, only functions are supported as a tool. Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported.\n" + ), + ] = None + tool_choice: Optional[ChatCompletionToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + user: Annotated[ + Optional[str], + Field( + description="A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids).\n", + examples=["user-1234"], + ), + ] = None + function_call: Annotated[ + Optional[Union[Literal["none", "auto"], ChatCompletionFunctionCallOption]], + Field( + description='Deprecated in favor of `tool_choice`.\n\nControls which (if any) function is called by the model.\n`none` means the model will not call a function and instead generates a message.\n`auto` means the model can pick between generating a message or calling a function.\nSpecifying a particular function via `{"name": "my_function"}` forces the model to call that function.\n\n`none` is the default when no functions are present. `auto` is the default if functions are present.\n' + ), + ] = None + functions: Annotated[ + Optional[List[ChatCompletionFunctions]], + Field( + description="Deprecated in favor of `tools`.\n\nA list of functions the model may generate JSON inputs for.\n", + max_length=128, + min_length=1, + ), + ] = None + + +class CreateThreadAndRunRequest(BaseModel): + model_config = ConfigDict( + extra="forbid", + ) + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run." + ), + ] + thread: Annotated[ + Optional[CreateThreadRequest], + Field(description="If no thread is provided, an empty thread will be created."), + ] = None + model: Annotated[ + Optional[ + Union[ + Optional[str], + Literal[ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ], + ] + ], + Field( + description="The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used.", + examples=["gpt-4o"], + ), + ] = None + instructions: Annotated[ + Optional[str], + Field( + description="Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis." + ), + ] = None + tools: Annotated[ + Optional[List[Union[AssistantToolsCode, AssistantToolsFileSearch, AssistantToolsFunction]]], + Field( + description="Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis.", + max_length=20, + ), + ] = None + tool_resources: Annotated[ + Optional[ToolResources3], + Field( + description="A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs.\n" + ), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] = None + temperature: Annotated[ + Optional[float], + Field( + description="What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic.\n", + examples=[1], + ge=0.0, + le=2.0, + ), + ] = 1 + top_p: Annotated[ + Optional[float], + Field( + description="An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered.\n\nWe generally recommend altering this or temperature but not both.\n", + examples=[1], + ge=0.0, + le=1.0, + ), + ] = 1 + stream: Annotated[ + Optional[bool], + Field( + description="If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message.\n" + ), + ] = None + max_prompt_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] = None + max_completion_tokens: Annotated[ + Optional[int], + Field( + description="The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info.\n", + ge=256, + ), + ] = None + truncation_strategy: Optional[TruncationObject] = None + tool_choice: Optional[AssistantsApiToolChoiceOption] = None + parallel_tool_calls: Optional[ParallelToolCalls] = None + response_format: Optional[AssistantsApiResponseFormatOption] = None + + +class RunStepObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the run step, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.run.step"], + Field(description="The object type, which is always `thread.run.step`."), + ] + created_at: Annotated[ + int, + Field(description="The Unix timestamp (in seconds) for when the run step was created."), + ] + assistant_id: Annotated[ + str, + Field( + description="The ID of the [assistant](/docs/api-reference/assistants) associated with the run step." + ), + ] + thread_id: Annotated[ + str, + Field(description="The ID of the [thread](/docs/api-reference/threads) that was run."), + ] + run_id: Annotated[ + str, + Field( + description="The ID of the [run](/docs/api-reference/runs) that this run step is a part of." + ), + ] + type: Annotated[ + Literal["message_creation", "tool_calls"], + Field( + description="The type of run step, which can be either `message_creation` or `tool_calls`." + ), + ] + status: Annotated[ + Literal["in_progress", "cancelled", "failed", "completed", "expired"], + Field( + description="The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`." + ), + ] + step_details: Annotated[ + Union[RunStepDetailsMessageCreationObject, RunStepDetailsToolCallsObject], + Field(description="The details of the run step."), + ] + last_error: Annotated[ + Optional[LastError1], + Field( + description="The last error associated with this run step. Will be `null` if there are no errors." + ), + ] + expired_at: Annotated[ + Optional[int], + Field( + description="The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired." + ), + ] = None + cancelled_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the run step was cancelled."), + ] = None + failed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the run step failed."), + ] = None + completed_at: Annotated[ + Optional[int], + Field(description="The Unix timestamp (in seconds) for when the run step completed."), + ] = None + metadata: Annotated[ + Optional[Dict[str, Any]], + Field( + description="Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long.\n" + ), + ] + usage: RunStepCompletionUsage + + +class Delta1(BaseModel): + step_details: Annotated[ + Optional[ + Union[ + RunStepDeltaStepDetailsMessageCreationObject, + RunStepDeltaStepDetailsToolCallsObject, + ] + ], + Field(description="The details of the run step."), + ] = None + + +class RunStepDeltaObject(BaseModel): + id: Annotated[ + str, + Field( + description="The identifier of the run step, which can be referenced in API endpoints." + ), + ] + object: Annotated[ + Literal["thread.run.step.delta"], + Field(description="The object type, which is always `thread.run.step.delta`."), + ] + delta: Annotated[ + Delta1, + Field(description="The delta containing the fields that have changed on the run step."), + ] + + +class ListRunStepsResponse(BaseModel): + object: Annotated[str, Field(examples=["list"])] + data: List[RunStepObject] + first_id: Annotated[str, Field(examples=["step_abc123"])] + last_id: Annotated[str, Field(examples=["step_abc456"])] + has_more: Annotated[bool, Field(examples=[False])] + + +class RunStepStreamEvent1(BaseModel): + event: Literal["thread.run.step.created"] + data: RunStepObject + + +class RunStepStreamEvent2(BaseModel): + event: Literal["thread.run.step.in_progress"] + data: RunStepObject + + +class RunStepStreamEvent3(BaseModel): + event: Literal["thread.run.step.delta"] + data: RunStepDeltaObject + + +class RunStepStreamEvent4(BaseModel): + event: Literal["thread.run.step.completed"] + data: RunStepObject + + +class RunStepStreamEvent5(BaseModel): + event: Literal["thread.run.step.failed"] + data: RunStepObject + + +class RunStepStreamEvent6(BaseModel): + event: Literal["thread.run.step.cancelled"] + data: RunStepObject + + +class RunStepStreamEvent7(BaseModel): + event: Literal["thread.run.step.expired"] + data: RunStepObject + + +class RunStepStreamEvent( + RootModel[ + Union[ + RunStepStreamEvent1, + RunStepStreamEvent2, + RunStepStreamEvent3, + RunStepStreamEvent4, + RunStepStreamEvent5, + RunStepStreamEvent6, + RunStepStreamEvent7, + ] + ] +): + root: Union[ + RunStepStreamEvent1, + RunStepStreamEvent2, + RunStepStreamEvent3, + RunStepStreamEvent4, + RunStepStreamEvent5, + RunStepStreamEvent6, + RunStepStreamEvent7, + ] + + +class AssistantStreamEvent( + RootModel[ + Union[ + ThreadStreamEvent, + RunStreamEvent, + RunStepStreamEvent, + MessageStreamEvent, + ErrorEvent, + DoneEvent, + ] + ] +): + root: Annotated[ + Union[ + ThreadStreamEvent, + RunStreamEvent, + RunStepStreamEvent, + MessageStreamEvent, + ErrorEvent, + DoneEvent, + ], + Field( + description='Represents an event emitted when streaming a Run.\n\nEach event in a server-sent events stream has an `event` and `data` property:\n\n```\nevent: thread.created\ndata: {"id": "thread_123", "object": "thread", ...}\n```\n\nWe emit events whenever a new object is created, transitions to a new state, or is being\nstreamed in parts (deltas). For example, we emit `thread.run.created` when a new run\nis created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses\nto create a message during a run, we emit a `thread.message.created event`, a\n`thread.message.in_progress` event, many `thread.message.delta` events, and finally a\n`thread.message.completed` event.\n\nWe may add additional events over time, so we recommend handling unknown events gracefully\nin your code. See the [Assistants API quickstart](/docs/assistants/overview) to learn how to\nintegrate the Assistants API with streaming.\n' + ), + ] diff --git a/server/llm_engine_server/core/docker/__init__.py b/model-engine/model_engine_server/core/__init__.py similarity index 100% rename from server/llm_engine_server/core/docker/__init__.py rename to model-engine/model_engine_server/core/__init__.py diff --git a/server/llm_engine_server/core/utils/__init__.py b/model-engine/model_engine_server/core/auth/__init__.py similarity index 100% rename from server/llm_engine_server/core/utils/__init__.py rename to model-engine/model_engine_server/core/auth/__init__.py diff --git a/server/llm_engine_server/core/auth/authentication_repository.py b/model-engine/model_engine_server/core/auth/authentication_repository.py similarity index 52% rename from server/llm_engine_server/core/auth/authentication_repository.py rename to model-engine/model_engine_server/core/auth/authentication_repository.py index ce1cf9b9..2d3b591e 100644 --- a/server/llm_engine_server/core/auth/authentication_repository.py +++ b/model-engine/model_engine_server/core/auth/authentication_repository.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from dataclasses import dataclass +from dataclasses import dataclass, field from typing import Optional @@ -7,7 +7,9 @@ class User: user_id: str team_id: str - is_privileged_user: bool + email: Optional[str] = field(repr=False, default=None) + team_email: Optional[str] = field(repr=False, default=None) + is_privileged_user: bool = False class AuthenticationRepository(ABC): @@ -16,26 +18,21 @@ class AuthenticationRepository(ABC): With the context of the Model Primitive service, this just refers to a (user_id, team_id) pair. """ + @staticmethod @abstractmethod - def get_auth_from_user_id(self, user_id: str) -> Optional[User]: + def is_allowed_team(team: str) -> bool: """ - Returns authentication information associated with a given user_id. + Returns whether the provided team is an allowed team. """ @abstractmethod - def get_auth_from_api_key(self, api_key: str) -> Optional[User]: + def get_auth_from_username(self, username: str) -> Optional[User]: """ - Returns authentication information associated with a given api_key. + Returns authentication information associated with a given Basic Auth username. """ @abstractmethod - async def get_auth_from_user_id_async(self, user_id: str) -> Optional[User]: + async def get_auth_from_username_async(self, username: str) -> Optional[User]: """ - Returns authentication information associated with a given user_id. - """ - - @abstractmethod - async def get_auth_from_api_key_async(self, api_key: str) -> Optional[User]: - """ - Returns authentication information associated with a given api_key. + Returns authentication information associated with a given Basic Auth username. """ diff --git a/model-engine/model_engine_server/core/auth/fake_authentication_repository.py b/model-engine/model_engine_server/core/auth/fake_authentication_repository.py new file mode 100644 index 00000000..ff38e768 --- /dev/null +++ b/model-engine/model_engine_server/core/auth/fake_authentication_repository.py @@ -0,0 +1,22 @@ +from typing import Dict, Optional + +from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User + + +class FakeAuthenticationRepository(AuthenticationRepository): + def __init__(self, user_team_override: Optional[Dict[str, str]] = None): + if user_team_override is None: + user_team_override = {} + self.user_team_override = user_team_override + + @staticmethod + def is_allowed_team(team: str) -> bool: + return True + + def get_auth_from_username(self, username: str) -> Optional[User]: + team_id = self.user_team_override.get(username, username) + return User(user_id=username, team_id=team_id, is_privileged_user=True) + + async def get_auth_from_username_async(self, username: str) -> Optional[User]: + team_id = self.user_team_override.get(username, username) + return User(user_id=username, team_id=team_id, is_privileged_user=True) diff --git a/server/llm_engine_server/db/__init__.py b/model-engine/model_engine_server/core/aws/__init__.py similarity index 100% rename from server/llm_engine_server/db/__init__.py rename to model-engine/model_engine_server/core/aws/__init__.py diff --git a/server/llm_engine_server/core/aws/roles.py b/model-engine/model_engine_server/core/aws/roles.py similarity index 91% rename from server/llm_engine_server/core/aws/roles.py rename to model-engine/model_engine_server/core/aws/roles.py index 52c71d48..d33efeca 100644 --- a/server/llm_engine_server/core/aws/roles.py +++ b/model-engine/model_engine_server/core/aws/roles.py @@ -11,7 +11,7 @@ import boto3 from boto3 import Session, client from botocore.client import BaseClient -from llm_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.loggers import logger_name, make_logger logger = make_logger(logger_name()) @@ -73,9 +73,6 @@ def client(self, client_type: str, region_name: str = "us-west-2") -> BaseClient """Creates the specified Boto3 :param:`client_type` using the AWS credentials. The :param:`client_type` parameter is any valid value for `boto3.client` (e.g. `"s3"`). - - NOTE: Use the us-west-2 region unless you are absolutely sure you require a different region. - All Scale AWS services are in US West 2. """ return boto3.client( client_type, @@ -122,15 +119,6 @@ def session(role: Optional[str], session_type: SessionT = Session) -> SessionT: :param:`session_type` defines the type of session to return. Most users will use the default boto3 type. Some users required a special type (e.g aioboto3 session). - - Includes fall-back logic to work with setups that do not use a credentials file - in the .aws folder in the user's home folder. In this setting, it is ok for :param:`role` - to be an ARN. Otherwise, the `profile_to_arn` mapping in `scaleml.config` is used to - locate the correct ARN for the given AWS profile name. - - NOTE: The fall-back is required for this to work with setups that use `aws-okta`. - - :raises: botocore.exceptions.ProfileNotFound, ValueError """ # Do not assume roles in CIRCLECI if os.getenv("CIRCLECI"): @@ -178,7 +166,6 @@ def _session_aws_okta( return sesh -# returns scale user (e.g. pranav.pillai) def get_current_user() -> str: """Uses AWS sts to obtain the profile name of the currently authenticated AWS account.""" arn = client("sts").get_caller_identity().get("Arn") diff --git a/server/llm_engine_server/core/aws/secrets.py b/model-engine/model_engine_server/core/aws/secrets.py similarity index 60% rename from server/llm_engine_server/core/aws/secrets.py rename to model-engine/model_engine_server/core/aws/secrets.py index 4d8e8941..0637b121 100644 --- a/server/llm_engine_server/core/aws/secrets.py +++ b/model-engine/model_engine_server/core/aws/secrets.py @@ -1,27 +1,24 @@ """AWS secrets module.""" + import json from functools import lru_cache from typing import Optional import boto3 from botocore.exceptions import ClientError -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) @lru_cache(maxsize=2) def get_key_file(secret_name: str, aws_profile: Optional[str] = None): if aws_profile is not None: session = boto3.Session(profile_name=aws_profile) - secret_manager = session.client( - "secretsmanager", region_name=ml_infra_config().default_region - ) + secret_manager = session.client("secretsmanager", region_name=infra_config().default_region) else: - secret_manager = boto3.client( - "secretsmanager", region_name=ml_infra_config().default_region - ) + secret_manager = boto3.client("secretsmanager", region_name=infra_config().default_region) try: secret_value = json.loads( secret_manager.get_secret_value(SecretId=secret_name)["SecretString"] diff --git a/server/llm_engine_server/core/aws/storage_client.py b/model-engine/model_engine_server/core/aws/storage_client.py similarity index 86% rename from server/llm_engine_server/core/aws/storage_client.py rename to model-engine/model_engine_server/core/aws/storage_client.py index 526eda3a..814b00c4 100644 --- a/server/llm_engine_server/core/aws/storage_client.py +++ b/model-engine/model_engine_server/core/aws/storage_client.py @@ -3,9 +3,9 @@ import smart_open from botocore.client import BaseClient -from llm_engine_server.core.aws.roles import session -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.aws.roles import session +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger logger = make_logger(logger_name()) @@ -20,7 +20,7 @@ def sync_storage_client(**kwargs) -> BaseClient: - return session(ml_infra_config().profile_ml_worker).client("s3", **kwargs) + return session(infra_config().profile_ml_worker).client("s3", **kwargs) # type: ignore def open(uri: str, mode: str = "rt", **kwargs) -> IO: # pylint: disable=redefined-builtin @@ -30,10 +30,7 @@ def open(uri: str, mode: str = "rt", **kwargs) -> IO: # pylint: disable=redefin def sync_storage_client_keepalive( - s3_client: BaseClient, - buckets: Iterable[str], - interval: int, - is_cancelled: Callable[[], bool], + s3_client: BaseClient, buckets: Iterable[str], interval: int, is_cancelled: Callable[[], bool] ) -> None: """Keeps connection pool warmed up for access on list of S3 buckets. diff --git a/model-engine/model_engine_server/core/celery/__init__.py b/model-engine/model_engine_server/core/celery/__init__.py new file mode 100644 index 00000000..3368bc69 --- /dev/null +++ b/model-engine/model_engine_server/core/celery/__init__.py @@ -0,0 +1,19 @@ +from typing import Sequence + +from .app import ( + DEFAULT_TASK_VISIBILITY_SECONDS, + TaskVisibility, + celery_app, + get_all_db_indexes, + get_redis_host_port, + inspect_app, +) + +__all__: Sequence[str] = ( + "celery_app", + "get_all_db_indexes", + "get_redis_host_port", + "inspect_app", + "TaskVisibility", + "DEFAULT_TASK_VISIBILITY_SECONDS", +) diff --git a/model-engine/model_engine_server/core/celery/abs.py b/model-engine/model_engine_server/core/celery/abs.py new file mode 100644 index 00000000..ea303947 --- /dev/null +++ b/model-engine/model_engine_server/core/celery/abs.py @@ -0,0 +1,23 @@ +from azure.core.exceptions import ResourceExistsError +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from celery.backends.azureblockblob import AzureBlockBlobBackend as DefaultAzureBlockBlobBackend +from kombu.utils import cached_property + + +class AzureBlockBlobBackend(DefaultAzureBlockBlobBackend): + @cached_property + def _blob_service_client(self): + client = BlobServiceClient( + f"https://{self._connection_string}.blob.core.windows.net", + credential=DefaultAzureCredential(), + connection_timeout=self._connection_timeout, + read_timeout=self._read_timeout, + ) + + try: + client.create_container(name=self._container_name) + except ResourceExistsError: + pass + + return client diff --git a/server/llm_engine_server/core/celery/app.py b/model-engine/model_engine_server/core/celery/app.py similarity index 90% rename from server/llm_engine_server/core/celery/app.py rename to model-engine/model_engine_server/core/celery/app.py index 924d268f..af7790d1 100644 --- a/server/llm_engine_server/core/celery/app.py +++ b/model-engine/model_engine_server/core/celery/app.py @@ -9,9 +9,10 @@ from celery.app import backends from celery.app.control import Inspect from celery.result import AsyncResult -from llm_engine_server.core.aws.roles import session -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import ( +from model_engine_server.core.aws.roles import session +from model_engine_server.core.aws.secrets import get_key_file +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import ( CustomJSONFormatter, logger_name, make_logger, @@ -25,7 +26,13 @@ # This is because the Celery code does not actually work when you try and # override the backend with a class instead of a URL, despite the fact # that the `backend` constructor arg type is a Union[str, Type[celery.backends.base.Backend]] -backends.BACKEND_ALIASES["s3"] = "llm_engine_server.core.celery.s3:S3Backend" +backends.BACKEND_ALIASES["s3"] = "model_engine_server.core.celery.s3:S3Backend" +backends.BACKEND_ALIASES["azureblockblob"] = ( + "model_engine_server.core.celery.abs:AzureBlockBlobBackend" +) + + +DEFAULT_TASK_VISIBILITY_SECONDS = 86400 @unique @@ -60,7 +67,7 @@ class TaskVisibility(IntEnum): 2. When making requests to such deployment, you'll have to do: ```python - from scaleml.io.celery import TaskVisibility, celery_app + from model_engine_server.core.celery.app import TaskVisibility, celery_app app = celery_app(None, task_visibility=TaskVisibility.VISIBILITY_1M) future_result = app.send_task("some.task.name", args=["some", "args"], queue="some-queue") ``` @@ -171,7 +178,7 @@ def get_redis_host_port(): port = os.getenv("REDIS_PORT") # In the case of k8s, pick the right endpoint based on the config elif os.getenv("KUBERNETES_SERVICE_HOST"): - host = ml_infra_config().redis_host + host = infra_config().redis_host port = 6379 # For debugging purposes elif os.getenv("USE_REDIS_LOCALHOST") == "1": @@ -180,8 +187,8 @@ def get_redis_host_port(): port = 6379 # In the case of local testing, pick the right endpoint based on the config elif os.getenv("KUBECONFIG"): - logger.info(f"Inferring redis host from config env: {ml_infra_config().env}") - host = f"redis-elasticache-message-broker.{ml_infra_config().dns_host_domain}" + logger.info(f"Inferring redis host from config env: {infra_config().env}") + host = f"redis-elasticache-message-broker.{infra_config().dns_host_domain}" port = 6379 logger.info(f"Using Redis host and port: {host}:{port}") @@ -189,6 +196,17 @@ def get_redis_host_port(): def get_redis_endpoint(db_index: int = 0) -> str: + if infra_config().redis_aws_secret_name is not None: + logger.info("Using infra_config().redis_aws_secret_name for Redis endpoint") + creds = get_key_file(infra_config().redis_aws_secret_name) # Use default role + scheme = creds.get("scheme", "redis://") + host = creds["host"] + port = creds["port"] + query_params = creds.get("query_params", "") + auth_token = creds.get("auth_token", None) + if auth_token is not None: + return f"{scheme}:{auth_token}@{host}:{port}/{db_index}{query_params}" + return f"{scheme}{host}:{port}/{db_index}{query_params}" host, port = get_redis_host_port() auth_token = os.getenv("REDIS_AUTH_TOKEN") if auth_token: @@ -292,7 +310,7 @@ def celery_app( 2. When making requests to such deployment, you'll have to do: ```python - from scaleml.io.celery import TaskVisibility, celery_app + from model_engine_server.core.celery import TaskVisibility, celery_app app = celery_app(None, task_visibility=TaskVisibility.VISIBILITY_1M) future_result = app.send_task("some.task.name", args=["some", "args"], queue="some-queue") ``` @@ -342,21 +360,21 @@ def celery_app( # FIXME: serializer. Until we figure out how to run as a non-root user, it might be better to avoid pickle. :param s3_bucket: [optional] Bucket name to store task results when using S3 as backend. The results uri will be - "s3:////...". Defaults to "scale-ml" (s3://scale-ml/tmp/celery/). + "s3:////...". :param s3_base_path: [optional] Base path for task results when using S3 as backend. The results uri will be - "s3:////...". Defaults to "tmp/celery/" (s3://scale-ml/tmp/celery/). + "s3:////...". - :param backend_protocol: [optional] Backend protocol to use, currently supports "s3" and "redis". + :param backend_protocol: [optional] Backend protocol to use, currently supports "s3", "redis", and "abs". Defaults to "s3". Redis might be faster than S3 but is not persistent, so using "redis" is discouraged. If you do end up using this, make sure you set up `result_expires` (https://docs.celeryproject.org/en/stable/userguide/configuration.html#result-expires) to something reasonable (1 day by default) and run `celery beat` periodically to clear expired results from Redis. Visit https://docs.celeryproject.org/en/stable/userguide/periodic-tasks.html to learn more about celery beat - :param broker_type: [defaults to "redis"] The broker type. We currently support "redis" and "sqs". + :param broker_type: [defaults to "redis"] The broker type. We currently support "redis", "sqs", and "servicebus". - :param aws_role: [optional] AWS role to use. If none, will default to default for s3 backends, + :param aws_role: [optional] AWS role to use. :param extra_changes: Extra keyword arguments to Celery app. Visit https://docs.celeryproject.org/en/stable/userguide/configuration.html to see options. @@ -430,7 +448,7 @@ def celery_app( } if s3_bucket is None: - s3_bucket = ml_infra_config().s3_bucket + s3_bucket = infra_config().s3_bucket backend_url, extra_conf_changes = _get_backend_url_and_conf( backend_protocol, @@ -473,6 +491,11 @@ def _get_broker_endpoint_and_transport_options( # Going to try this with defaults first. out_broker_transport_options["region"] = os.environ.get("AWS_REGION", "us-west-2") + # changing wait_time_seconds from the default of 10 based on https://github.com/celery/celery/discussions/7283 + # goal is to prevent async requests from being stuck in pending when workers die; the hypothesis is that this is caused by SQS long polling + out_broker_transport_options["wait_time_seconds"] = 0 + out_broker_transport_options["polling_interval"] = 5 + # NOTE: The endpoints should ideally use predefined queues. However, the sender probably needs the flexibility # of not requiring predefined queues. # assert ( @@ -481,9 +504,14 @@ def _get_broker_endpoint_and_transport_options( # Plain "sqs://" signifies to use instance metadata. return "sqs://", out_broker_transport_options + if broker_type == "servicebus": + return ( + f"azureservicebus://DefaultAzureCredential@{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net", + out_broker_transport_options, + ) raise ValueError( - f"Only 'redis' and 'sqs' are supported values for broker_type, got value {broker_type}" + f"Only 'redis', 'sqs', and 'servicebus' are supported values for broker_type, got value {broker_type}" ) @@ -504,7 +532,7 @@ def _get_backend_url_and_conf( elif backend_protocol == "s3": backend_url = "s3://" if aws_role is None: - aws_session = session(ml_infra_config().profile_ml_worker) + aws_session = session(infra_config().profile_ml_worker) else: aws_session = session(aws_role) out_conf_changes.update( @@ -514,9 +542,11 @@ def _get_backend_url_and_conf( "s3_base_path": s3_base_path, } ) + elif backend_protocol == "abs": + backend_url = f"azureblockblob://{os.getenv('ABS_ACCOUNT_NAME')}" else: raise ValueError( - f'Unknown backend protocol "{backend_protocol}". Should be one of ["s3", "redis"].' + f'Unknown backend protocol "{backend_protocol}". Should be one of ["s3", "redis", "abs].' ) return backend_url, out_conf_changes diff --git a/model-engine/model_engine_server/core/celery/celery_autoscaler.py b/model-engine/model_engine_server/core/celery/celery_autoscaler.py new file mode 100644 index 00000000..54e3c5bc --- /dev/null +++ b/model-engine/model_engine_server/core/celery/celery_autoscaler.py @@ -0,0 +1,693 @@ +import asyncio as aio +import dataclasses +import hashlib +import logging +import os +import time +from abc import ABC, abstractmethod +from bisect import bisect +from collections import defaultdict +from concurrent.futures import ThreadPoolExecutor +from math import ceil +from typing import Any, DefaultDict, Dict, List, Set, Tuple + +import aioredis +import stringcase +from azure.core.exceptions import ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.servicebus.management import ServiceBusAdministrationClient +from celery.app.control import Inspect +from datadog import statsd +from kubernetes_asyncio import client +from kubernetes_asyncio import config as kube_config +from kubernetes_asyncio.client.rest import ApiException +from kubernetes_asyncio.config.config_exception import ConfigException +from model_engine_server.core.aws.roles import session +from model_engine_server.core.celery import ( + TaskVisibility, + celery_app, + get_all_db_indexes, + get_redis_host_port, + inspect_app, +) +from model_engine_server.core.loggers import logger_name, make_logger + + +def excluded_namespaces(): + try: + from plugins.celery_autoscaler_dependencies import CELERY_AUTOSCALER_EXCLUDED_NAMESPACES + + return CELERY_AUTOSCALER_EXCLUDED_NAMESPACES + except ModuleNotFoundError: + return [] + + +ELASTICACHE_REDIS_BROKER = "redis-elasticache-message-broker-master" +SQS_BROKER = "sqs-message-broker-master" +SERVICEBUS_BROKER = "servicebus-message-broker-master" + +UPDATE_DEPLOYMENT_MAX_RETRIES = 10 + +SQS_SAMPLE_COUNT = 10 + +logger = make_logger(logger_name()) + +autoscaler_broker = os.environ.get("BROKER_NAME", SQS_BROKER) +aws_profile = os.environ.get("AWS_PROFILE") +aws_region = os.environ.get("AWS_REGION", "us-west-2") + + +@dataclasses.dataclass +class CeleryAutoscalerParams: + queue: str + broker: str = SQS_BROKER + task_visibility: TaskVisibility = TaskVisibility.VISIBILITY_1H + per_worker: int = 1 + min_workers: int = 0 + max_workers: int = 1 + + +def _hash_any_to_int(data: Any): + return int(hashlib.md5(str(data).encode()).hexdigest(), 16) # nosemgrep + + +async def list_deployments(core_api, apps_api) -> Dict[Tuple[str, str], CeleryAutoscalerParams]: + namespaces = await core_api.list_namespace() + celery_deployments_params = {} + for namespace in namespaces.items: + namespace_name = namespace.metadata.name + if namespace_name in excluded_namespaces(): + continue + namespace_start_time = time.time() + deployments = await apps_api.list_namespaced_deployment(namespace=namespace_name) + logger.info( + f"list_namespaced_deployment with {namespace_name} took {time.time() - namespace_start_time} seconds" + ) + for deployment in deployments.items: + deployment_name = deployment.metadata.name + annotations = deployment.metadata.annotations + + if not annotations: + continue + + # Parse parameters + params = {} + + if "celery.scaleml.autoscaler/broker" in annotations: + deployment_broker = annotations["celery.scaleml.autoscaler/broker"] + else: + deployment_broker = ELASTICACHE_REDIS_BROKER + + if deployment_broker != autoscaler_broker: + logger.debug( + f"Skipping deployment {deployment_name}; deployment's broker {deployment_broker} is not {autoscaler_broker}" + ) + continue + + for f in dataclasses.fields(CeleryAutoscalerParams): + k = f.name + v = annotations.get(f"celery.scaleml.autoscaler/{stringcase.camelcase(k)}") + if not v: + continue + + try: + if k == "task_visibility": + v = TaskVisibility.from_name(v) + v = f.type(v) + except (ValueError, KeyError): + logger.exception(f"Unable to convert {f.name}: {v} to {f.type}") + + params[k] = v + + try: + celery_autoscaler_params = CeleryAutoscalerParams(**params) + except TypeError: + logger.debug( + f"Missing params, skipping deployment : {deployment_name} in {namespace_name}" + ) + continue + + celery_deployments_params[(deployment_name, namespace_name)] = celery_autoscaler_params + + return celery_deployments_params + + +class InstanceLogger(logging.LoggerAdapter): + def process(self, msg, kwargs): + return "%s %s" % (self.extra["name"], msg), kwargs + + +class Instance: + def __init__(self, api, name, namespace, params: CeleryAutoscalerParams, env): + self.api = api + self.name = name + self.namespace = namespace + self.params = params + self.history: List[Tuple[float, float]] = [] + self.logger = InstanceLogger(logger, {"name": name}) + self.env = env + + async def check_queue_size_and_update_deployment(self, queue_size: int) -> None: + workers_wanted = ceil(queue_size / self.params.per_worker) + + time_now = time.monotonic() + self.history.append((workers_wanted, time_now)) + + # Take last 10 minutes + times = [t for _, t in self.history] + evict = bisect(times, time_now - 600) + self.history = self.history[evict:] + + workers_wanted = max(self.history)[0] # type: ignore + workers_wanted = min(self.params.max_workers, workers_wanted) + workers_wanted = max(self.params.min_workers, workers_wanted) + + await self.update_deployment(workers_wanted) + + async def update_deployment(self, workers_wanted) -> None: + for _ in range(UPDATE_DEPLOYMENT_MAX_RETRIES): + try: + dep = await self.api.read_namespaced_deployment( + name=self.name, namespace=self.namespace + ) + + if dep.spec.replicas == workers_wanted: + self.logger.debug("Deployment not updated.") + break + + dep.spec.replicas = workers_wanted + + await self.api.patch_namespaced_deployment( + name=self.name, + namespace=self.namespace, + body=dep, + ) + + self.logger.info(f"Deployment updated. replicas={dep.spec.replicas}") + emit_health_metric("scaling_succeeded", self.env) + return + except ApiException as exc: + if exc.status == 409: + self.logger.info("409 retry") + continue + elif exc.status == 404: + self.logger.warning("404 not found") + return + emit_health_metric("scaling_failed", self.env) + raise + else: + emit_health_metric("scaling_failed", self.env) + raise Exception("Ran out of retries updating deployment") + + +@dataclasses.dataclass +class QueueSizes: + """Obtained from Inspect.active()""" + + active: int = 0 + + """Obtained from Inspect.active() + """ + reserved: int = 0 + + """Computed by summing Redis queue lengths across all db_indexes. + """ + enqueued: int = 0 + + """The sum of all of other fields. + """ + total: int = 0 + + # Ignoring these other Inspect categories for now, since they have a different structure + # from 'active' and 'reserved'. We can add them later if we want - it'd just require some + # more complexity to parse them out. + # + # scheduled: int = 0 + # revoked: int = 0 + # registered: int = 0 + + +@dataclasses.dataclass +class WorkerMetrics: + """ + Key: db_index + Value: number of workers + """ + + worker_counts: DefaultDict[int, int] + + +@dataclasses.dataclass +class BrokerMetrics: + """ + Key: (queue_name, db_index) + Value: QueueSizes + """ + + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes] + + """" + Represents the number of active redis client connections + """ + connection_count: int + + """ + Represents the max number of redis client connections allowed + """ + max_connections: int + + +@dataclasses.dataclass +class Metrics: + worker_metrics: WorkerMetrics + broker_metrics: BrokerMetrics + + +def emit_metrics( + metrics: Metrics, + env: str, +) -> None: + """ + Emits a given mapping of queue sizes to Datadog. + """ + queue_sizes = metrics.broker_metrics.queue_sizes + for q, queue_size in queue_sizes.items(): + queue_name, _ = q + tags = [ + f"env:{env}", + f"queue:{queue_name}", + ] + + for metric_name, metric_value in queue_size.__dict__.items(): + statsd.gauge(f"celery.queue_size.{metric_name}", metric_value, tags=tags) + + # Redis-specific, can be ignored for sqs (worker_counts should be empty anyways) + for db_index, worker_count in metrics.worker_metrics.worker_counts.items(): + task_visibility = TaskVisibility(db_index).name.lower() + tags = [ + f"env:{env}", + f"task_visibility:{task_visibility}", + ] + statsd.gauge("celery.worker_count", worker_count, tags=tags) + + if metrics.broker_metrics.connection_count is not None: + tags = [ + f"env:{env}", + ] + statsd.gauge( + "celery.connection_count", + metrics.broker_metrics.connection_count, + tags=tags, + ) + + if metrics.broker_metrics.max_connections is not None: + tags = [ + f"env:{env}", + ] + statsd.gauge("celery.max_connections", metrics.broker_metrics.max_connections, tags=tags) + + +def emit_health_metric(metric_name: str, env: str): + tags = [f"env:{env}"] + statsd.increment(f"celery_autoscaler.{metric_name}", tags=tags) + + +class AutoscalerBroker(ABC): + """ + Base class for autoscaler brokers. + """ + + @abstractmethod + async def get_broker_metrics( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ) -> BrokerMetrics: + """ + Calculates broker related metrics. + + Args: + queues: a set of (queue_name, db_index) + queue_sizes: number of active and reserved tasks for each queue + + Returns: broker metrics + """ + + +class RedisBroker(AutoscalerBroker): + def __init__(self, use_elasticache: bool, initialized: bool = False): + self.use_elasticache = use_elasticache + self.initialized = initialized + + async def _init_client(self): + ( + host, + port, + ) = ( + get_redis_host_port() + ) # Switches the redis instance based on CELERY_ELASTICACHE_ENABLED's value + self.redis = { + db_index: aioredis.client.Redis.from_url(f"redis://{host}:{port}/{db_index}") + for db_index in get_all_db_indexes() + } + self.initialized = True + + async def _get_queue_sizes( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ): + if not self.initialized: + await self._init_client() + + for queue_name, db_index in queues: + q = (queue_name, db_index) + enqueued = await self.redis[db_index].llen(queue_name) + queue_sizes[q].enqueued += enqueued + queue_sizes[q].total += enqueued + return queue_sizes + + async def _get_connection_count(self): + redis_client = next(iter(self.redis.values()), None) # get any redis client + + if redis_client is not None: + if ( + self.use_elasticache + ): # We are using elasticache which doesn't allow us to do `CONFIG GET` + info = await redis_client.info() + connection_count = info.get("connected_clients") + max_connections = info.get("maxclients") + else: + (info, config) = await aio.gather( + redis_client.info(), + redis_client.config_get("maxclients"), + ) + max_connections = config.get("maxclients") + connection_count = info.get("connected_clients") + + return connection_count, max_connections + + async def get_broker_metrics( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ) -> BrokerMetrics: + queue_sizes = await self._get_queue_sizes(queues, queue_sizes) + connection_count, max_connections = await self._get_connection_count() + return BrokerMetrics( + queue_sizes=queue_sizes, + connection_count=connection_count, + max_connections=max_connections, + ) + + +class SQSBroker(AutoscalerBroker): + @staticmethod + def _get_sqs_queue_size(queue_name: str): + sqs_client = session(aws_profile).client("sqs", region_name=aws_region) + try: + total_start_time = time.time() + queue_size_hist = [] + reserved_size_hist = [] + # We intentionally launch several requests to the same queue. + # We have found multiple samples results in more accurate length estimates compared to a single request. + # Performance-wise: The first request takes ~0.5s, subsequent requests take ~0.005s + for _ in range(SQS_SAMPLE_COUNT): + response = sqs_client.get_queue_attributes( + QueueUrl=queue_name, + AttributeNames=[ + "ApproximateNumberOfMessages", + "ApproximateNumberOfMessagesNotVisible", + ], + ) + queue_size_hist.append(int(response["Attributes"]["ApproximateNumberOfMessages"])) + reserved_size_hist.append( + int(response["Attributes"]["ApproximateNumberOfMessagesNotVisible"]) + ) + total_end_time = time.time() + queue_size = max(queue_size_hist) + # SQS's ApproximateNumberOfMessagesNotVisible should correspond to celery's + # number of active + number of reserved tasks + reserved_size = max(reserved_size_hist) + logger.info( + f"SQS {queue_name} total: {total_end_time - total_start_time} seconds, queue size {queue_size}, reserved size {reserved_size}" + ) + + except sqs_client.exceptions.QueueDoesNotExist as e: + logger.info(f"Queue does not exist {queue_name}: {e}") + queue_size = 0 + reserved_size = 0 + except Exception as e: + logger.error(f"Failed to get queue attributes {queue_name}: {e}") + queue_size = 0 + reserved_size = 0 + return queue_size, reserved_size + + def _get_queue_sizes( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ): + queue_names = [queue_name for queue_name, _ in queues] + with ThreadPoolExecutor() as executor: + results = executor.map(SQSBroker._get_sqs_queue_size, queue_names) + + for q, (enqueued, reserved) in zip(queues, results): + queue_sizes[q].enqueued += enqueued + queue_sizes[q].reserved += reserved + queue_sizes[q].total += enqueued + reserved + return queue_sizes + + async def get_broker_metrics( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ) -> BrokerMetrics: + queue_sizes = self._get_queue_sizes(queues, queue_sizes) + return BrokerMetrics( + queue_sizes=queue_sizes, + connection_count=None, + max_connections=None, + ) # connection_count and max_connections are redis-specific metrics + + +class ASBBroker(AutoscalerBroker): + @staticmethod + def _get_asb_queue_size(queue_name: str): + with ServiceBusAdministrationClient( + f"{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net", + credential=DefaultAzureCredential(), + ) as client: + try: + queue_attributes = client.get_queue_runtime_properties(queue_name=queue_name) + active_queue_size = queue_attributes.active_message_count + + logger.info(f"ASB {queue_name} total: active queue size {active_queue_size}") + except ResourceNotFoundError as e: + logger.info(f"Queue does not exist {queue_name}: {e}") + active_queue_size = 0 + except Exception as e: + logger.error(f"Failed to get queue attributes {queue_name}: {e}") + active_queue_size = 0 + + return active_queue_size + + def _get_queue_sizes( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ): + queue_names = [queue_name for queue_name, _ in queues] + with ThreadPoolExecutor() as executor: + results = executor.map(ASBBroker._get_asb_queue_size, queue_names) + + for q, active_queue_size in zip(queues, results): + queue_sizes[q].enqueued += active_queue_size + queue_sizes[q].total += active_queue_size + return queue_sizes + + async def get_broker_metrics( + self, + queues: Set[Tuple[str, int]], + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes], + ) -> BrokerMetrics: + queue_sizes = self._get_queue_sizes(queues, queue_sizes) + return BrokerMetrics( + queue_sizes=queue_sizes, + connection_count=None, + max_connections=None, + ) # connection_count and max_connections are redis-specific metrics + + +def get_worker_metrics( + inspect: Dict[int, Inspect], + queues: Set[Tuple[str, int]], +) -> Tuple[WorkerMetrics, DefaultDict[Tuple[str, int], QueueSizes]]: + """ + Given a set of Celery Inspect results for each db connection, + computes the number of workers for each db connection, and number of active and reserved tasks. + + In the case of SQS this will return no data for queue_sizes/worker counts, as inspect is empty + """ + queue_sizes: DefaultDict[Tuple[str, int], QueueSizes] = defaultdict(QueueSizes) + worker_counts: DefaultDict[int, int] = defaultdict(int) + for db_index, insp in inspect.items(): + insp_categories = { + "active": insp.active(), + "reserved": insp.reserved(), + } + + worker_ping = insp.ping() + if worker_ping: + worker_counts[db_index] = len(worker_ping.values()) + + for insp_key, worker_group in filter(lambda x: x[1], insp_categories.items()): + for task_list in worker_group.values(): + for task in task_list: + queue_name = task["delivery_info"]["routing_key"] + q = (queue_name, db_index) + + if q in queues: + queue_sizes[q].__dict__[insp_key] += 1 + queue_sizes[q].total += 1 + return WorkerMetrics(worker_counts=worker_counts), queue_sizes + + +async def get_metrics( + broker: AutoscalerBroker, + inspect: Dict[int, Inspect], + queues: Set[Tuple[str, int]], +) -> Metrics: + """ + Given a set of Redis db connections and Celery Inspect results for each db connection, + computes worker and broker metrics. + """ + + worker_metrics, active_reserved_queue_sizes = get_worker_metrics(inspect, queues) + broker_metrics = await broker.get_broker_metrics(queues, active_reserved_queue_sizes) + + return Metrics( + worker_metrics=worker_metrics, + broker_metrics=broker_metrics, + ) + + +async def main(): + instances: Dict[Tuple[str, str], Instance] = {} + try: + kube_config.load_incluster_config() + except ConfigException: + logger.info("No incluster kubernetes config, falling back to local") + await kube_config.load_kube_config() + + core_api = client.CoreV1Api() + apps_api = client.AppsV1Api() + + BROKER_NAME_TO_CLASS = { + ELASTICACHE_REDIS_BROKER: RedisBroker(use_elasticache=True), + SQS_BROKER: SQSBroker(), + SERVICEBUS_BROKER: ASBBroker(), + } + + broker = BROKER_NAME_TO_CLASS[autoscaler_broker] + broker_type = ( + "redis" + if isinstance(broker, RedisBroker) + else "sqs" if isinstance(broker, SQSBroker) else "servicebus" + ) + + if broker_type == "redis": + inspect = { + db_index: inspect_app( + app=celery_app( + None, broker_type=broker_type, task_visibility=db_index, aws_role=aws_profile + ) + ) + for db_index in get_all_db_indexes() + } + elif broker_type == "sqs": + # for sqs we will get active/reserved counts directly from sqs as opposed to using + # an inspect object + inspect = {} + elif broker_type == "servicebus": + inspect = { + 0: inspect_app(app=celery_app(None, broker_type=broker_type, backend_protocol="abs")) + } + else: + raise ValueError("broker_type not redis, sqs, or servicebus; how did we get here?") + + env = os.getenv("DD_ENV") + instance_count = int(os.getenv("POD_NAME", "pod-0").split("-")[-1]) + num_shards = int(os.getenv("NUM_SHARDS", 1)) + + env = f"{env}-{autoscaler_broker}" + + while True: + try: + loop_start = time.time() + deployments = await list_deployments(core_api=core_api, apps_api=apps_api) + logger.info(f"list_deployments took {time.time() - loop_start} seconds") + celery_queues = set() + celery_queues_params = [] + for deployment_and_namespace, params in sorted( + deployments.items() + ): # sort for a bit more determinism + # Hash the deployment / namespace to deterministically partition the deployments. + # Skip all deployments not in this partition. + if _hash_any_to_int(deployment_and_namespace) % num_shards != instance_count: + continue + + deployment_name, namespace = deployment_and_namespace + instance = instances.get(deployment_and_namespace) + if instance is None or instance.params != params: + instances[deployment_and_namespace] = Instance( + apps_api, deployment_name, namespace, params, env + ) + + # We're treating a queue as a pair consisting of a (queue_name, db_index). + # This means that two queues that happen to have the same name are treated + # as semantically distinct if they have different db_indexes. + celery_queues.add((params.queue, params.task_visibility.value)) + celery_queues_params.append(params.__dict__) + + # Clean up instances not in set + for deployment_and_namespace in set(instances) - set(deployments): + del instances[deployment_and_namespace] + + # Get queue sizes + # (queue_name, db_index) -> QueueSizes + start_get_metrics = time.time() + metrics = await get_metrics(broker, inspect=inspect, queues=celery_queues) + logger.info(f"get_metrics took {time.time() - start_get_metrics} seconds") + + queue_sizes = metrics.broker_metrics.queue_sizes + for k, v in sorted(queue_sizes.items()): + queue_name, _ = k + logger.info(f"Inflight : {queue_name} : {v.total}") + + emit_metrics(metrics=metrics, env=env) + + # Update scaling + for instance in instances.values(): + queue_size = queue_sizes[ + (instance.params.queue, int(instance.params.task_visibility)) + ] + try: + await instance.check_queue_size_and_update_deployment(queue_size.total) + except Exception as e: + logger.exception(f"Failed to update {instance.name}: {e}") + + # Wait before next iteration + iteration_len = time.time() - loop_start + logger.info(f"Iteration length: {iteration_len} seconds.") + if iteration_len < 3: + await aio.sleep(3 - iteration_len) + + emit_health_metric("heartbeat", env) + except Exception as e: + logger.exception(f"Error in deployment loop: {e}") + continue + + +if __name__ == "__main__": + aio.run(main()) diff --git a/server/llm_engine_server/core/celery/s3.py b/model-engine/model_engine_server/core/celery/s3.py similarity index 100% rename from server/llm_engine_server/core/celery/s3.py rename to model-engine/model_engine_server/core/celery/s3.py diff --git a/model-engine/model_engine_server/core/config.py b/model-engine/model_engine_server/core/config.py new file mode 100644 index 00000000..0474976c --- /dev/null +++ b/model-engine/model_engine_server/core/config.py @@ -0,0 +1,114 @@ +"""AWS configuration for ml-infra-services. + +The configuration file is loaded from the ML_INFRA_SERVICES_CONFIG_PATH environment variable. +If this is not set, the default configuration file is used from +model_engine_server.core/configs/default.yaml. +""" + +import inspect +import os +from contextlib import contextmanager +from copy import deepcopy +from dataclasses import dataclass +from pathlib import Path +from typing import Optional, Sequence + +import yaml +from model_engine_server.core.loggers import logger_name, make_logger + +logger = make_logger(logger_name()) + +__all__: Sequence[str] = ( + "DEFAULT_CONFIG_PATH", + "CONFIG_PATH", + "config_context", + "get_config_path_for_env_name", + "infra_config", + "use_config_context", +) + +DEFAULT_CONFIG_PATH = Path(__file__).parent / "configs" / "default.yaml" +CONFIG_PATH: str = os.getenv("ML_INFRA_SERVICES_CONFIG_PATH", str(DEFAULT_CONFIG_PATH)) + + +@dataclass +class _InfraConfig: + cloud_provider: str + env: str + k8s_cluster_name: str + dns_host_domain: str + default_region: str + ml_account_id: str + docker_repo_prefix: str + s3_bucket: str + redis_host: Optional[str] = None + redis_aws_secret_name: Optional[str] = None + profile_ml_worker: str = "default" + profile_ml_inference_worker: str = "default" + identity_service_url: Optional[str] = None + firehose_role_arn: Optional[str] = None + firehose_stream_name: Optional[str] = None + prometheus_server_address: Optional[str] = None + + +@dataclass +class DBEngineConfig: + db_engine_pool_size: int = 10 + db_engine_max_overflow: int = 10 + db_engine_echo: bool = False + db_engine_echo_pool: bool = False + db_engine_disconnect_strategy: str = "pessimistic" + + +@dataclass +class InfraConfig(DBEngineConfig, _InfraConfig): + @classmethod + def from_json(cls, json): + return cls(**{k: v for k, v in json.items() if k in inspect.signature(cls).parameters}) + + @classmethod + def from_yaml(cls, yaml_path) -> "InfraConfig": + with open(yaml_path, "r") as f: + raw_data = yaml.safe_load(f) + return InfraConfig.from_json(raw_data) + + +def read_default_config(): + logger.info(f"Using config file path: `{CONFIG_PATH}`") + return InfraConfig.from_yaml(CONFIG_PATH) + + +_infra_config: Optional[InfraConfig] = None + + +def infra_config() -> InfraConfig: + global _infra_config + if _infra_config is None: + _infra_config = read_default_config() + return _infra_config + + +@contextmanager +def config_context(config_path: str): + """Context manager that temporarily changes the config file path.""" + global _infra_config + current_config = deepcopy(_infra_config) + try: + _infra_config = InfraConfig.from_yaml(config_path) + yield + finally: + _infra_config = current_config + + +def use_config_context(config_path: str): + """Use the config file at the given path.""" + global _infra_config + _infra_config = InfraConfig.from_yaml(config_path) + + +def get_config_path_for_env_name(env_name: str) -> Path: + path = DEFAULT_CONFIG_PATH.parent / f"{env_name}.yaml" + if not path.exists(): + print(path) + raise ValueError(f"Config file does not exist for env: {env_name}") + return path diff --git a/model-engine/model_engine_server/core/configmap.py b/model-engine/model_engine_server/core/configmap.py new file mode 100644 index 00000000..d3edb669 --- /dev/null +++ b/model-engine/model_engine_server/core/configmap.py @@ -0,0 +1,35 @@ +"""Read configmap from k8s.""" + +from typing import Dict + +from kubernetes_asyncio import client +from kubernetes_asyncio import config as kube_config +from kubernetes_asyncio.client.rest import ApiException +from kubernetes_asyncio.config.config_exception import ConfigException +from model_engine_server.common.config import hmi_config +from model_engine_server.core.loggers import logger_name, make_logger + +DEFAULT_NAMESPACE = "default" + +logger = make_logger(logger_name()) + + +async def read_config_map( + config_map_name: str, namespace: str = hmi_config.gateway_namespace +) -> Dict[str, str]: + try: + kube_config.load_incluster_config() + except ConfigException: + logger.info("No incluster kubernetes config, falling back to local") + await kube_config.load_kube_config() + + core_api = client.CoreV1Api() + + try: + config_map = await core_api.read_namespaced_config_map( + name=config_map_name, namespace=namespace + ) + return config_map.data + except ApiException as e: + logger.exception(f"Error reading configmap {config_map_name}") + raise e diff --git a/server/llm_engine_server/core/configs/circleci.yaml b/model-engine/model_engine_server/core/configs/default.yaml similarity index 62% rename from server/llm_engine_server/core/configs/circleci.yaml rename to model-engine/model_engine_server/core/configs/default.yaml index 2ef8183b..2e2e6ec0 100644 --- a/server/llm_engine_server/core/configs/circleci.yaml +++ b/model-engine/model_engine_server/core/configs/default.yaml @@ -1,3 +1,4 @@ +cloud_provider: "aws" env: "circleci" k8s_cluster_name: "minikube" dns_host_domain: "localhost" @@ -5,6 +6,11 @@ default_region: "us-west-2" ml_account_id: "000000000000" docker_repo_prefix: "000000000000.dkr.ecr.us-west-2.amazonaws.com" redis_host: "redis-message-broker-master.default" -s3_bucket: "scale-ml-circleci" +s3_bucket: "test-bucket" profile_ml_worker: "default" profile_ml_inference_worker: "default" +db_engine_pool_size: 10 +db_engine_max_overflow: 10 +db_engine_echo: false +db_engine_echo_pool: false +db_engine_disconnect_strategy: "pessimistic" diff --git a/server/llm_engine_server/db/models/common/__init__.py b/model-engine/model_engine_server/core/docker/__init__.py similarity index 100% rename from server/llm_engine_server/db/models/common/__init__.py rename to model-engine/model_engine_server/core/docker/__init__.py diff --git a/server/llm_engine_server/core/docker/docker_image.py b/model-engine/model_engine_server/core/docker/docker_image.py similarity index 88% rename from server/llm_engine_server/core/docker/docker_image.py rename to model-engine/model_engine_server/core/docker/docker_image.py index b46065a0..8d68f8c8 100644 --- a/server/llm_engine_server/core/docker/docker_image.py +++ b/model-engine/model_engine_server/core/docker/docker_image.py @@ -5,7 +5,6 @@ """ import base64 -import logging import os import pathlib import subprocess @@ -16,16 +15,12 @@ import boto3 import click import docker -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import make_logger +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger from .remote_build import MODELS_ROOT, build_remote_wrapper -logger = make_logger("llm_engine_server.core.docker.docker_image", log_level=logging.INFO) - -REGISTRY_ID = ml_infra_config().ml_account_id -ECR_REGION = ml_infra_config().default_region -ECR_REPO = f"{REGISTRY_ID}.dkr.ecr.{ECR_REGION}.amazonaws.com" +logger = make_logger(logger_name()) def _get_aws_creds() -> Dict[str, str]: @@ -108,7 +103,7 @@ def build( ) # Make sure not to do this after grabbing the AWS creds, so that we don't print them out. tag = _get_image_tag(image_tag) - image = f"{ECR_REPO}/{service_name}:{tag}" + image = f"{infra_config().docker_repo_prefix}/{service_name}:{tag}" local_args["image"] = image @@ -163,13 +158,13 @@ def build( command=test_command, volumes={ os.path.join(home_dir, ".aws"): { - "bind": "/root/.aws/config", + "bind": "/opt/.aws/config", "mode": "ro", } }, environment={ - "AWS_PROFILE": ml_infra_config().profile_ml_worker, - "AWS_CONFIG_FILE": "/root/.aws/config", + "AWS_PROFILE": infra_config().profile_ml_worker, + "AWS_CONFIG_FILE": "/opt/.aws/config", }, remove=True, ) @@ -190,14 +185,14 @@ def push(service_name: str, image_tag: Optional[str] = None) -> None: logger.info(f"push args: {local_args}") docker_client = docker.from_env() - ecr_client = boto3.client("ecr", region_name=ECR_REGION) - token = ecr_client.get_authorization_token(registryIds=[REGISTRY_ID]) + ecr_client = boto3.client("ecr", region_name=infra_config().default_region) + token = ecr_client.get_authorization_token(registryIds=[infra_config().ml_account_id]) username, password = ( base64.b64decode(token["authorizationData"][0]["authorizationToken"]).decode().split(":") ) output = docker_client.images.push( - repository=f"{ECR_REPO}/{service_name}", + repository=f"{infra_config().docker_repo_prefix}/{service_name}", tag=_get_image_tag(image_tag), auth_config={"username": username, "password": password}, stream=True, diff --git a/server/llm_engine_server/core/docker/ecr.py b/model-engine/model_engine_server/core/docker/ecr.py similarity index 73% rename from server/llm_engine_server/core/docker/ecr.py rename to model-engine/model_engine_server/core/docker/ecr.py index 1192a4d2..fcd324b9 100644 --- a/server/llm_engine_server/core/docker/ecr.py +++ b/model-engine/model_engine_server/core/docker/ecr.py @@ -1,18 +1,17 @@ from typing import Dict, List, Optional import boto3 -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.utils.git import tag +from model_engine_server.core.config import infra_config +from model_engine_server.core.utils.git import tag DEFAULT_FILTER = {"tagStatus": "TAGGED"} def repository_exists(repository_name: str): - ecr = boto3.client("ecr", region_name=ml_infra_config().default_region) + ecr = boto3.client("ecr", region_name=infra_config().default_region) try: response = ecr.describe_repositories( - registryId=ml_infra_config().ml_account_id, - repositoryNames=[repository_name], + registryId=infra_config().ml_account_id, repositoryNames=[repository_name] ) if response.get("repositories"): return True @@ -23,7 +22,7 @@ def repository_exists(repository_name: str): def batch_image_exists( *, - region_name: str = ml_infra_config().default_region, + region_name: str = infra_config().default_region, repository_name: str, image_tags: Optional[List[str]] = None, image_digests: Optional[List[str]] = None, @@ -45,7 +44,7 @@ def batch_image_exists( client = session.client("ecr", region_name=region_name) try: client.describe_images( - registryId=ml_infra_config().ml_account_id, + registryId=infra_config().ml_account_id, repositoryName=repository_name, imageIds=[ *[{"imageTag": t} for t in image_tags], @@ -61,7 +60,7 @@ def batch_image_exists( def image_exists( *, - region_name: str = ml_infra_config().default_region, + region_name: str = infra_config().default_region, repository_name: str, image_name: Optional[str] = None, image_tag: Optional[str] = None, @@ -88,13 +87,25 @@ def ecr_exists_for_repo(repo_name: str, image_tag: Optional[str] = None): """Check if image exists in ECR""" if image_tag is None: image_tag = tag() - ecr = boto3.client("ecr", region_name=ml_infra_config().default_region) + ecr = boto3.client("ecr", region_name=infra_config().default_region) try: ecr.describe_images( - registryId=ml_infra_config().ml_account_id, + registryId=infra_config().ml_account_id, repositoryName=repo_name, imageIds=[{"imageTag": image_tag}], ) return True except ecr.exceptions.ImageNotFoundException: return False + + +def get_latest_image_tag(repository_name: str): + ecr = boto3.client("ecr", region_name=infra_config().default_region) + images = ecr.describe_images( + registryId=infra_config().ml_account_id, + repositoryName=repository_name, + filter=DEFAULT_FILTER, + maxResults=1000, + )["imageDetails"] + latest_image = max(images, key=lambda image: image["imagePushedAt"]) + return latest_image["imageTags"][0] diff --git a/server/llm_engine_server/core/docker/kaniko_template.yaml b/model-engine/model_engine_server/core/docker/kaniko_template.yaml similarity index 90% rename from server/llm_engine_server/core/docker/kaniko_template.yaml rename to model-engine/model_engine_server/core/docker/kaniko_template.yaml index d87f89f0..4f842367 100644 --- a/server/llm_engine_server/core/docker/kaniko_template.yaml +++ b/model-engine/model_engine_server/core/docker/kaniko_template.yaml @@ -33,27 +33,29 @@ spec: - "--cache=$USE_CACHE" - "--cache-copy-layers=$USE_CACHE" - "--cache-run-layers=$USE_CACHE" - - "--cache-repo=000000000000.dkr.ecr.us-west-2.amazonaws.com/kaniko-cache" + - "--cache-repo=$CACHE_REPO" - "--cleanup" - "--snapshot-mode=redo" - "--use-new-run" - "--image-fs-extract-retry=5" - "--log-format=json" + - "--push-retry=2" # The --use-new-run flag should fix docker builds eating up a lot of memory and consequently oom/failing env: - name: AWS_REGION value: us-west-2 + # TODO we need to parametrize AWS_REGION volumeMounts: - name: pipconf mountPath: /kaniko/pip resources: requests: cpu: 3.5 - memory: 30Gi + memory: 90Gi ephemeral-storage: 80G limits: cpu: 3.5 - memory: 30Gi + memory: 90Gi ephemeral-storage: 80G volumes: - name: pipconf diff --git a/server/llm_engine_server/core/docker/kaniko_template_circleci.yaml b/model-engine/model_engine_server/core/docker/kaniko_template_circleci.yaml similarity index 100% rename from server/llm_engine_server/core/docker/kaniko_template_circleci.yaml rename to model-engine/model_engine_server/core/docker/kaniko_template_circleci.yaml diff --git a/server/llm_engine_server/core/docker/remote_build.py b/model-engine/model_engine_server/core/docker/remote_build.py similarity index 90% rename from server/llm_engine_server/core/docker/remote_build.py rename to model-engine/model_engine_server/core/docker/remote_build.py index 99881cb4..5b192064 100644 --- a/server/llm_engine_server/core/docker/remote_build.py +++ b/model-engine/model_engine_server/core/docker/remote_build.py @@ -20,13 +20,13 @@ from kubernetes import config as kube_config from kubernetes import watch from kubernetes.config.config_exception import ConfigException -from llm_engine_server.core.aws import storage_client -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.aws import storage_client +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger logger = make_logger(logger_name()) -S3_BUCKET = os.environ.get("S3_BUCKET", ml_infra_config().s3_bucket) +S3_BUCKET = os.environ.get("S3_BUCKET", infra_config().s3_bucket) SUB_BUCKET = "tmp/docker_contexts" # Adjust if either this file or kaniko_template.yaml moves! OWN_FILE_PATH = Path(__file__).resolve() @@ -48,6 +48,7 @@ class BuildResult: status: bool logs: str + job_name: str def zip_context( @@ -60,7 +61,7 @@ def zip_context( Takes a path to a folder, zips up the folder and sticks it into s3 :param s3_file_name: Bucket/file for context tar.gz, will upload to here - :param context: Path to context for dockerfile, relative to calling script, i.e. box_detection/ or scaleml/ if you're running from models/ + :param context: Path to context for dockerfile, relative to calling script :param folders_to_include: List of paths to subfolders needed to build docker image, relative to context :param ignore_file: File (e.g. .dockerignore) containing things to ignore when preparing docker context. Relative to context. Contents of file are parsed according to tar's --exclude-from, which differs slightly from @@ -70,7 +71,8 @@ def zip_context( assert len(folders_to_include) > 0 assert s3_file_name.endswith(".gz") - print(f"Uploading to s3 at: {s3_file_name}") + s3_uri = f"s3://{S3_BUCKET}/{s3_file_name}" + print(f"Uploading to s3 at: {s3_uri}") try: # Need to gimme_okta_aws_creds (you can export AWS_PROFILE='ml-admin' right after) tar_command = _build_tar_cmd(context, ignore_file, folders_to_include) @@ -83,7 +85,7 @@ def zip_context( ) as proc: assert proc.stdout is not None with storage_client.open( - f"s3://{S3_BUCKET}/{s3_file_name}", + s3_uri, "wb", ) as out_file: shutil.copyfileobj(proc.stdout, out_file) @@ -122,6 +124,7 @@ def start_build_job( path_to_dockerfile: str, repotags: Iterable[str], use_cache: bool, + cache_name: str, build_args: Optional[Dict[str, str]] = None, custom_tags: Optional[Dict[str, str]] = None, ) -> str: @@ -142,8 +145,7 @@ def start_build_job( custom_tags_serialized = json.dumps(custom_tags) destination_template = Template( - f"--destination={ml_infra_config().ml_account_id}.dkr.ecr." - f"{ml_infra_config().default_region}.amazonaws.com/$REPO_AND_TAG" + f"--destination={infra_config().docker_repo_prefix}/$REPO_AND_TAG" ) job_name = f"kaniko-{str(uuid.uuid4())[:8]}" @@ -157,15 +159,11 @@ def start_build_job( aws_secret_access_key = "" if os.getenv("CIRCLECI"): aws_access_key_id_result = subprocess.run( - ["aws", "configure", "get", "aws_access_key_id"], - check=False, - stdout=PIPE, + ["aws", "configure", "get", "aws_access_key_id"], check=False, stdout=PIPE ) aws_access_key_id = aws_access_key_id_result.stdout.decode().strip() aws_secret_access_key_result = subprocess.run( - ["aws", "configure", "get", "aws_secret_access_key"], - check=False, - stdout=PIPE, + ["aws", "configure", "get", "aws_secret_access_key"], check=False, stdout=PIPE ) aws_secret_access_key = aws_secret_access_key_result.stdout.decode().strip() job = Template(template_f.read()).substitute( @@ -175,6 +173,7 @@ def start_build_job( S3_BUCKET=S3_BUCKET, S3_FILE=s3_file_name, USE_CACHE="true" if use_cache else "false", + CACHE_REPO=f"{infra_config().docker_repo_prefix}/{cache_name}", AWS_ACCESS_KEY_ID=aws_access_key_id, AWS_SECRET_ACCESS_KEY=aws_secret_access_key, NAMESPACE=NAMESPACE, @@ -197,25 +196,25 @@ def start_build_job( if not os.path.exists("/tmp"): os.makedirs("/tmp") pip_conf_file = "/tmp/.codeartifact-pip-conf" - aws_profile = ml_infra_config().profile_ml_worker - subprocess.check_output( - [ - f"AWS_PROFILE={aws_profile} python scripts_py3/scale_scripts/exe/maybe_refresh_codeartifact.py --export {pip_conf_file}" - ], - cwd=str(MODELS_ROOT), - shell=True, - ) - with open(pip_conf_file) as f_conf: - pip_conf_base64 = b64encode(f_conf.read().encode("utf-8")).decode("utf-8") + aws_profile = infra_config().profile_ml_worker + try: + # nosemgrep + subprocess.check_output( + [ + f"AWS_PROFILE={aws_profile} python scripts_py3/scale_scripts/exe/maybe_refresh_codeartifact.py --export {pip_conf_file}" + ], + cwd=str(MODELS_ROOT), + shell=True, + ) + with open(pip_conf_file) as f_conf: + pip_conf_data = f_conf.read() + except (subprocess.CalledProcessError, FileNotFoundError): + print("WARNING: Failed to refresh CodeArtifact token secret, using empty secret") + pip_conf_data = "" + pip_conf_base64 = b64encode(pip_conf_data.encode("utf-8")).decode("utf-8") data = {"data": {"codeartifact_pip_conf": pip_conf_base64}} subprocess.check_output( - [ - "kubectl", - "patch", - "secret", - "codeartifact-pip-conf", - f"-p={json.dumps(data)}", - ] + ["kubectl", "patch", "secret", "codeartifact-pip-conf", f"-p={json.dumps(data)}"] ).decode("utf-8") print(f"Executing Kaniko build command:\n{container_spec}") @@ -231,6 +230,7 @@ def build_remote( repotags: Union[str, Iterable[str]], folders_to_include: Optional[List[str]] = None, use_cache: bool = True, + cache_name: str = "kaniko-cache", ignore_file: Optional[str] = None, build_args: Optional[Dict[str, str]] = None, custom_tags: Optional[Dict[str, str]] = None, @@ -262,7 +262,7 @@ def build_remote( calling_path = Path(context).resolve() if folders_to_include is None: if calling_path == MODELS_ROOT: - default_folders = {"scaleml/"} + default_folders = {} # find the models/ project folder that this Dockerfile comes from parts = dockerfile.split("/") @@ -292,7 +292,9 @@ def build_remote( folders_to_include=folders_to_include, ignore_file=ignore_file, ) - return start_build_job(s3_file_name, dockerfile, repotags, use_cache, build_args, custom_tags) + return start_build_job( + s3_file_name, dockerfile, repotags, use_cache, cache_name, build_args, custom_tags + ) def verify_and_reformat_as_relative_to(context: str, dockerfile: str) -> str: @@ -407,13 +409,13 @@ def cleanup_logs_process(): ) elif event["object"].status.phase == "Succeeded": cleanup_logs_process() - return BuildResult(status=True, logs=_read_pod_logs(pod_name)) + return BuildResult(status=True, logs=_read_pod_logs(pod_name), job_name=job_name) elif event["object"].status.phase == "Failed": cleanup_logs_process() - return BuildResult(status=False, logs=_read_pod_logs(pod_name)) + return BuildResult(status=False, logs=_read_pod_logs(pod_name), job_name=job_name) if logs_process is not None: logs_process.kill() - return BuildResult(status=False, logs=_read_pod_logs(pod_name)) + return BuildResult(status=False, logs=_read_pod_logs(pod_name), job_name=job_name) def build_remote_block( @@ -422,6 +424,7 @@ def build_remote_block( repotags: Union[str, Iterable[str]], folders_to_include: Optional[List[str]] = None, use_cache: bool = True, + cache_name: str = "kaniko-cache", ignore_file: Optional[str] = None, build_args: Optional[Dict[str, str]] = None, custom_tags: Optional[Dict[str, str]] = None, @@ -439,16 +442,19 @@ def build_remote_block( :param ignore_file: File (e.g. .dockerignore) containing things to ignore when preparing docker context. Relative to context :return: BuildResult representing if docker image has successfully built/pushed """ + logger.info(f"build_remote_block args {locals()}") job_name = build_remote( context, dockerfile, repotags, folders_to_include, use_cache, + cache_name, ignore_file, build_args, custom_tags, ) + logger.info(f"Waiting for job {job_name} to finish") result = get_pod_status_and_log(job_name) return result @@ -477,7 +483,7 @@ def build_remote_block( @click.option( "--folders", required=False, - help="Comma separated list of folders (relative to context), e.g. 'scaleml/,template-project/", + help="Comma separated list of folders (relative to context", ) @click.option( "--no-cache", @@ -528,6 +534,8 @@ def build_remote_wrapper( custom_tags = json.loads(custom_tags) folders_to_include: Optional[List[str]] = folders.split(",") if folders is not None else None + cache_name = "kaniko-cache" + build_args = None if build_arg: build_arg_kvs = [arg.split("=") for arg in build_arg] @@ -540,6 +548,7 @@ def build_remote_wrapper( repotags=repotag, folders_to_include=folders_to_include, use_cache=not no_cache, + cache_name=cache_name, ignore_file=dockerignore, build_args=build_args, custom_tags=custom_tags, @@ -551,6 +560,7 @@ def build_remote_wrapper( repotags=repotag, folders_to_include=folders_to_include, use_cache=not no_cache, + cache_name=cache_name, ignore_file=dockerignore, build_args=build_args, custom_tags=custom_tags, diff --git a/server/llm_engine_server/core/fake_notification_gateway.py b/model-engine/model_engine_server/core/fake_notification_gateway.py similarity index 90% rename from server/llm_engine_server/core/fake_notification_gateway.py rename to model-engine/model_engine_server/core/fake_notification_gateway.py index 1b2ed19d..909c3037 100644 --- a/server/llm_engine_server/core/fake_notification_gateway.py +++ b/model-engine/model_engine_server/core/fake_notification_gateway.py @@ -1,7 +1,7 @@ from collections import defaultdict from typing import List -from llm_engine_server.core.notification_gateway import NotificationApp, NotificationGateway +from model_engine_server.core.notification_gateway import NotificationApp, NotificationGateway class FakeNotificationGateway(NotificationGateway): diff --git a/server/llm_engine_server/core/loggers.py b/model-engine/model_engine_server/core/loggers.py similarity index 83% rename from server/llm_engine_server/core/loggers.py rename to model-engine/model_engine_server/core/loggers.py index 5a0877a1..3a28d450 100644 --- a/server/llm_engine_server/core/loggers.py +++ b/model-engine/model_engine_server/core/loggers.py @@ -5,19 +5,18 @@ import sys import warnings from contextlib import contextmanager -from typing import Optional, Sequence +from enum import Enum +from typing import Dict, Optional, Sequence import ddtrace import json_log_formatter import tqdm -from ddtrace.helpers import get_correlation_ids +from ddtrace import tracer # DO NOT CHANGE LOGGING FORMAT LOG_FORMAT: str = "%(asctime)s %(levelname)s [%(name)s] [%(filename)s:%(lineno)d] - %(message)s" # REQUIRED FOR DATADOG COMPATIBILITY -ctx_var_request_id = contextvars.ContextVar("ctx_var_request_id", default=None) - __all__: Sequence[str] = ( # most common imports "make_logger", @@ -34,20 +33,38 @@ "silence_chatty_logger", "loggers_at_level", # utils - "filename_wo_ext", - "get_request_id", - "set_request_id", + "LoggerTagKey", + "LoggerTagManager", ) -def get_request_id() -> Optional[str]: - """Get the request id from the context variable.""" - return ctx_var_request_id.get() +class LoggerTagKey(str, Enum): + REQUEST_ID = "request_id" + TEAM_ID = "team_id" + USER_ID = "user_id" + REQUEST_SIZE = "request_size" + +class LoggerTagManager: + _context_vars: Dict[LoggerTagKey, contextvars.ContextVar] = {} -def set_request_id(request_id: str) -> None: - """Set the request id in the context variable.""" - ctx_var_request_id.set(request_id) + @classmethod + def get(cls, key: LoggerTagKey) -> Optional[str]: + """Get the value from the context variable.""" + ctx_var = cls._context_vars.get(key) + if ctx_var is not None: + return ctx_var.get() + return None + + @classmethod + def set(cls, key: LoggerTagKey, value: Optional[str]) -> None: + """Set the value in the context variable.""" + if value is not None: + ctx_var = cls._context_vars.get(key) + if ctx_var is None: + ctx_var = contextvars.ContextVar(f"ctx_var_{key.name.lower()}", default=None) + cls._context_vars[key] = ctx_var + ctx_var.set(value) def make_standard_logger(name: str, log_level: int = logging.INFO) -> logging.Logger: @@ -77,16 +94,15 @@ def json_record(self, message: str, extra: dict, record: logging.LogRecord) -> d extra["lineno"] = record.lineno extra["pathname"] = record.pathname - # add the http request id if it exists - request_id = ctx_var_request_id.get() - if request_id: - extra["request_id"] = request_id - - trace_id, span_id = get_correlation_ids() + # add additional logger tags + for tag_key in LoggerTagKey: + tag_value = LoggerTagManager.get(tag_key) + if tag_value: + extra[tag_key.value] = tag_value - # add ids to event dictionary - extra["dd.trace_id"] = trace_id or 0 - extra["dd.span_id"] = span_id or 0 + current_span = tracer.current_span() + extra["dd.trace_id"] = current_span.trace_id if current_span else 0 + extra["dd.span_id"] = current_span.span_id if current_span else 0 # add the env, service, and version configured for the tracer. # If tracing is not set up, then this should pull values from DD_ENV, DD_SERVICE, and DD_VERSION. @@ -120,19 +136,6 @@ def make_json_logger(name: str, log_level: int = logging.INFO) -> logging.Logger stream_handler = logging.StreamHandler() in_kubernetes = os.getenv("KUBERNETES_SERVICE_HOST") if in_kubernetes: - # Somewhat hacky way of determining if we're running in a Datadog environment. - # Note that if you 'kubectl logs' the pod, you'll still see the JSON logs. But you really should - # just be looking at the logs in Datadog at that point. - # - # NOTE: If you're thinking of disabling this outside of your local machine, please consider - # just piping to `jq` instead, e.g.: - # - # $ kubectl logs -lapp=celery-autoscaler-singleton | jq -r '[.time, .level, .message] | join(" - ")' - # - # this spits out: - # - # 2021-04-08T23:40:03.148308 - INFO - Missing params, skipping deployment : - # 2021-04-08T23:40:03.148440 - INFO - Missing params, skipping deployment : stream_handler.setFormatter(CustomJSONFormatter()) else: # Reading JSON logs in your terminal is kinda hard, and you can't make use of the structured data @@ -199,7 +202,7 @@ def logger_name(*, fallback_name: Optional[str] = None) -> str: # in which case we use it's file name if hasattr(calling_module, "__file__"): - return filename_wo_ext(calling_module.__file__) + return _filename_wo_ext(calling_module.__file__) # type: ignore if fallback_name is not None: fallback_name = fallback_name.strip() if len(fallback_name) > 0: @@ -246,8 +249,8 @@ def silence_chatty_logger(*logger_names, quieter=logging.FATAL) -> None: Accepts a variable number of logger names. """ - for name in logger_names: - log = logging.getLogger(name) + for logger_name in logger_names: + log = logging.getLogger(logger_name) log.setLevel(quieter) @@ -271,8 +274,8 @@ def silence_chatty_datadog_loggers(*, silence_internal_writer: bool = False) -> silence_chatty_logger("ddtrace.internal.writer", quieter=logging.FATAL) -@contextmanager -def loggers_at_level(*loggers_or_names, new_level: int) -> None: +@contextmanager # type: ignore +def loggers_at_level(*loggers_or_names, new_level: int) -> None: # type: ignore """Temporarily set one or more loggers to a specific level, resetting to previous levels on context end. :param:`loggers_or_names` is one or more :class:`logging.Logger` instances, or `str` names @@ -282,19 +285,19 @@ def loggers_at_level(*loggers_or_names, new_level: int) -> None: To illustrate use, see this pseudocode example: >>>> import logging - >>>> from llm_engine_server.core.loggers import loggers_at_level, make_logger + >>>> from model_engine_server.core.loggers import loggers_at_level, make_logger >>>> >>>> your_logger = make_logger('your_logger') >>>> >>>> with loggers_at_level( >>>> your_logger, - >>>> 'llm_engine_server.core.loggers', + >>>> 'model_engine_server.core.loggers', >>>> 'document_core.utils.k8s', >>>> new_level=logging.FATAL, >>>> ): >>>> # do_something_while_those_loggers_will_only_log_FATAL_messages >>>> your_logger.info("this will not be logged") - >>>> logging.getLogger('llm_engine_server.core.loggers').warning("neither will this") + >>>> logging.getLogger('model_engine_server.core.loggers').warning("neither will this") >>>> >>>> your_logger.info("this will be logged") """ @@ -313,6 +316,6 @@ def loggers_at_level(*loggers_or_names, new_level: int) -> None: log.setLevel(level) -def filename_wo_ext(filename: str) -> str: +def _filename_wo_ext(filename: str) -> str: """Gets the filename, without the file extension, if present.""" return os.path.split(filename)[1].split(".", 1)[0] diff --git a/server/llm_engine_server/core/notification_gateway.py b/model-engine/model_engine_server/core/notification_gateway.py similarity index 100% rename from server/llm_engine_server/core/notification_gateway.py rename to model-engine/model_engine_server/core/notification_gateway.py diff --git a/server/llm_engine_server/db/models/utils/__init__.py b/model-engine/model_engine_server/core/utils/__init__.py similarity index 100% rename from server/llm_engine_server/db/models/utils/__init__.py rename to model-engine/model_engine_server/core/utils/__init__.py diff --git a/server/llm_engine_server/core/utils/env.py b/model-engine/model_engine_server/core/utils/env.py similarity index 99% rename from server/llm_engine_server/core/utils/env.py rename to model-engine/model_engine_server/core/utils/env.py index ef9a6a88..3eb87dd8 100644 --- a/server/llm_engine_server/core/utils/env.py +++ b/model-engine/model_engine_server/core/utils/env.py @@ -1,4 +1,5 @@ """Utilities for working with environment variables.""" + import os from typing import ContextManager, Dict, Optional, Sequence, Union diff --git a/server/llm_engine_server/core/utils/format.py b/model-engine/model_engine_server/core/utils/format.py similarity index 99% rename from server/llm_engine_server/core/utils/format.py rename to model-engine/model_engine_server/core/utils/format.py index 39e82a57..a26bd2c5 100644 --- a/server/llm_engine_server/core/utils/format.py +++ b/model-engine/model_engine_server/core/utils/format.py @@ -1,4 +1,5 @@ """Utilities for formatting and printing messages, especially for CLI programs.""" + import traceback from logging import Logger from typing import Any, List, Optional, Sequence, Tuple, Union diff --git a/server/llm_engine_server/core/utils/git.py b/model-engine/model_engine_server/core/utils/git.py similarity index 100% rename from server/llm_engine_server/core/utils/git.py rename to model-engine/model_engine_server/core/utils/git.py diff --git a/server/llm_engine_server/core/utils/python_utils.py b/model-engine/model_engine_server/core/utils/python_utils.py similarity index 95% rename from server/llm_engine_server/core/utils/python_utils.py rename to model-engine/model_engine_server/core/utils/python_utils.py index 2cfcd2f8..2925c7d9 100644 --- a/server/llm_engine_server/core/utils/python_utils.py +++ b/model-engine/model_engine_server/core/utils/python_utils.py @@ -1,9 +1,10 @@ """Python-language-based utility functions.""" + import builtins from importlib import import_module from typing import Any, Optional -from llm_engine_server.core.utils.format import split_module_value, strip_non_empty +from model_engine_server.core.utils.format import split_module_value, strip_non_empty def dynamic_load(module_name: str, value_name: Optional[str], validate: bool = True) -> Any: diff --git a/server/llm_engine_server/core/utils/timer.py b/model-engine/model_engine_server/core/utils/timer.py similarity index 86% rename from server/llm_engine_server/core/utils/timer.py rename to model-engine/model_engine_server/core/utils/timer.py index 4f72b7a4..edd80891 100644 --- a/server/llm_engine_server/core/utils/timer.py +++ b/model-engine/model_engine_server/core/utils/timer.py @@ -1,4 +1,5 @@ """Utilities for timing code blocks.""" + import inspect import time from datetime import timedelta @@ -26,15 +27,14 @@ class timer: # pylint: disable=invalid-name The other use case is to pass in a `name` and a `logger`. The timing will be recorded when the context block is exited: - >>> from llm_engine_server.core.loggers import make_logger - >>> - >>> log = make_logger("my-main-program") + >>> from model_engine_server.core.loggers import make_logger, logger_name >>> + >>> log = make_logger(logger_name()) >>> >>> with timer(logger=log, name="timing-func-f"): >>> f() """ - __slots__ = ("logger", "name", "_duration", "start") + __slots__ = ("logger", "name", "_duration", "start", "start_lap") def __init__(self, logger: Optional[Logger] = None, name: str = "") -> None: self.logger = logger @@ -43,6 +43,7 @@ def __init__(self, logger: Optional[Logger] = None, name: str = "") -> None: # for start, -1 is the uninitialized value # it is set at the context-block entering method: __enter__ self.start: float = -1.0 + self.start_lap: float = -1.0 def __enter__(self) -> "timer": """Records start time: context-block entering function.""" @@ -63,6 +64,18 @@ def __exit__(self, *args) -> None: ) self._maybe_log_end_time() + def lap(self) -> float: + # Records a "lap time". Specifically if start is called at t_0, and lap is + # called at t_1 and t_2, then the returned values are t_1 - t_0 and t_2 - t_1. + # This does introduce extra overhead, however. + current_time = time.monotonic() + if self.start_lap == -1: + duration = current_time - self.start + else: + duration = current_time - self.start_lap + self.start_lap = current_time + return duration + def _maybe_log_end_time(self) -> None: if self.logger is not None: caller_namespace = "" diff --git a/server/llm_engine_server/core/utils/url.py b/model-engine/model_engine_server/core/utils/url.py similarity index 76% rename from server/llm_engine_server/core/utils/url.py rename to model-engine/model_engine_server/core/utils/url.py index e9d6c758..16ae3d6f 100644 --- a/server/llm_engine_server/core/utils/url.py +++ b/model-engine/model_engine_server/core/utils/url.py @@ -1,4 +1,5 @@ """URL-based utility functions.""" + import re from typing import NamedTuple, Optional @@ -8,6 +9,7 @@ class ParsedURL(NamedTuple): bucket: str key: str region: Optional[str] + account: Optional[str] = None def canonical_url(self) -> str: """Packs the parsed URL information into a standard form of @@ -23,6 +25,10 @@ def s3(bucket: str, key: str, region: Optional[str] = None) -> "ParsedURL": def gs(bucket: str, key: str, region: Optional[str] = None) -> "ParsedURL": return ParsedURL(protocol="gs", bucket=bucket, key=key, region=region) + @staticmethod + def azure(bucket: str, key: str, account: Optional[str] = None) -> "ParsedURL": + return ParsedURL(protocol="azure", bucket=bucket, key=key, account=account) + @staticmethod def cds(bucket: str, key: str, region: Optional[str] = None) -> "ParsedURL": return ParsedURL(protocol="scale-cds", bucket=bucket, key=key, region=region) @@ -32,7 +38,7 @@ class InvalidAttachmentUrl(ValueError): pass -def parse_attachment_url(url: str) -> ParsedURL: +def parse_attachment_url(url: str, clean_key: bool = True) -> ParsedURL: """Extracts protocol, bucket, region, and key from the :param:`url`. :raises: InvalidAttachmentUrl Iff the input `url` is not a valid AWS S3 or GCS url. @@ -42,6 +48,7 @@ def parse_attachment_url(url: str) -> ParsedURL: bucket = None region = None key = None + account = None # s3://bucket/key1/key2 match = re.search("^s3://([^/]+)/(.*?)$", url) @@ -54,6 +61,13 @@ def parse_attachment_url(url: str) -> ParsedURL: protocol = "gs" bucket, key = match.group(1), match.group(2) + # azure://bucket/key1/key2 + # for Azure Blob Storage, bucket refers to an ABS container + match = re.search("^azure://([^/]+)/(.*?)$", url) + if match: + protocol = "azure" + bucket, key = match.group(1), match.group(2) + # http://bucket.s3.amazonaws.com/key1/key2 match = re.search("^https?://(.+).s3.amazonaws.com(.*?)$", url) if match: @@ -85,9 +99,13 @@ def parse_attachment_url(url: str) -> ParsedURL: if match: bucket, key = match.group(1), match.group(2) - # pattern from https://docs.google.com/document/d/1WLbQXkQL7PLo0rkjU0RsI4SPAqUvV0WV1-FWkzicduc/edit - # scale-cds://62f2a2942a57fb0024e4dc3e/dgb6etBCrUHtOMQ#s3/scale-cds-private-us-west-2 - # scale-cds://57743957186fd0060017f1a1/json/0e09cdfc-adbb-4d88-acf7-d75a478328e3 + # https://account.blob.core.windows.net/bucket/key1/key2 + # for Azure Blob Storage, bucket refers to an ABS container + match = re.search("^https?://([^/]+).blob.core.windows.net/([^/]+)(.*?)$", url) + if match: + protocol = "azure" + account, bucket, key = match.group(1), match.group(2), match.group(3) + match = re.search("scale-cds://(\\w+)/([\\-\\w\\/]+)", url) if match: bucket, key = match.group(1), match.group(2) @@ -98,12 +116,13 @@ def parse_attachment_url(url: str) -> ParsedURL: "Invalid attachment URL: no bucket or key specified: \n" f"'{url}'" ) - def clean(val): - return val and val.strip("/") + def clean(v): + return v and v.strip("/") return ParsedURL( protocol=clean(protocol), bucket=clean(bucket), region=clean(region), - key=clean(key), + key=clean(key) if clean_key else key, + account=clean(account), ) diff --git a/server/llm_engine_server/domain/__init__.py b/model-engine/model_engine_server/db/__init__.py similarity index 100% rename from server/llm_engine_server/domain/__init__.py rename to model-engine/model_engine_server/db/__init__.py diff --git a/model-engine/model_engine_server/db/base.py b/model-engine/model_engine_server/db/base.py new file mode 100644 index 00000000..5033d8ad --- /dev/null +++ b/model-engine/model_engine_server/db/base.py @@ -0,0 +1,339 @@ +import os +import sys +import time +from dataclasses import dataclass +from typing import Iterator, Optional + +import sqlalchemy +from azure.identity import DefaultAzureCredential +from azure.keyvault.secrets import SecretClient +from model_engine_server.core.aws.secrets import get_key_file +from model_engine_server.core.config import InfraConfig, infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from sqlalchemy import Engine, create_engine +from sqlalchemy.ext.asyncio import AsyncEngine, async_sessionmaker, create_async_engine +from sqlalchemy.ext.declarative import declarative_base +from sqlalchemy.orm import sessionmaker +from sqlalchemy.pool import NullPool + +logger = make_logger(logger_name()) + + +def get_key_file_name(environment: str) -> str: + if infra_config().cloud_provider == "azure": + return f"{environment}-ml-infra-pg".replace("training", "prod").replace("-new", "") + return f"{environment}/ml_infra_pg".replace("training", "prod").replace("-new", "") + + +@dataclass +class DBConnection: + url: str + expiry_in_sec: Optional[int] = None + + +def get_engine_url( + env: Optional[str] = None, + read_only: bool = True, + sync: bool = True, +) -> DBConnection: + """Gets the URL of the Postgresql engine depending on the environment.""" + expiry_in_sec: Optional[int] = None + if os.getenv("ML_INFRA_DATABASE_URL"): + # In CircleCI environment, we set up a test in another container and specify the URL. + engine_url = os.getenv("ML_INFRA_DATABASE_URL") + elif "pytest" in sys.modules: + # If we are in a local testing environment, we can set up a test psql instance. + # pylint: disable=import-outside-toplevel + import testing.postgresql + + Postgresql = testing.postgresql.PostgresqlFactory( + cache_initialized_db=True, + ) + postgresql = Postgresql() + engine_url = postgresql.url() + else: + key_file = os.environ.get("DB_SECRET_NAME") + if env is None: + env = infra_config().env + if key_file is None: + key_file = get_key_file_name(env) # type: ignore + logger.debug(f"Using key file {key_file}") + + if infra_config().cloud_provider == "azure": + client = SecretClient( + vault_url=f"https://{os.environ.get('KEYVAULT_NAME')}.vault.azure.net", + credential=DefaultAzureCredential(), + ) + db = client.get_secret(key_file).value + user = os.environ.get("AZURE_IDENTITY_NAME") + token = DefaultAzureCredential().get_token( + "https://ossrdbms-aad.database.windows.net/.default" + ) + password = token.token + logger.info(f"Connecting to db {db} as user {user}") + + # TODO: https://docs.sqlalchemy.org/en/20/core/engines.html#generating-dynamic-authentication-tokens + # for recommendations on how to work with rotating auth credentials + engine_url = f"postgresql://{user}:{password}@{db}?sslmode=require" + expiry_in_sec = token.expires_on + else: + db_secret_aws_profile = os.environ.get("DB_SECRET_AWS_PROFILE") + creds = get_key_file(key_file, db_secret_aws_profile) + + user = creds.get("username") + password = creds.get("password") + host = creds.get("clusterHostRo") if read_only else creds.get("clusterHost") + port = str(creds.get("port")) + dbname = creds.get("dbname") + logger.info(f"Connecting to db {host}:{port}, name {dbname}") + + engine_url = f"postgresql://{user}:{password}@{host}:{port}/{dbname}" + + assert engine_url + + # For async postgres, we need to use an async dialect. + if not sync: + engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://").replace( + "sslmode", "ssl" + ) + return DBConnection(engine_url, expiry_in_sec) + + +@dataclass +class SyncDBSession: + engine: Engine + session: sessionmaker + + +@dataclass +class AsyncDBSession: + engine: AsyncEngine + session: async_sessionmaker + + +@dataclass +class DBSessions: + session_sync: SyncDBSession + session_sync_ro: SyncDBSession + session_async: AsyncDBSession + session_async_ro: AsyncDBSession + session_async_null_pool: AsyncDBSession + + +@dataclass +class DBEngineConfig: + pool_pre_ping: bool + pool_size: int + max_overflow: int + echo: bool + echo_pool: bool + + +class DBManager: + sessions: DBSessions + config: DBEngineConfig + + credential_expiration_timestamp: Optional[float] = None + credential_expiration_buffer_sec: int = 300 + + def _get_engine_url(self, read_only: bool, sync: bool) -> DBConnection: + return get_engine_url(read_only=read_only, sync=sync) + + def __init__(self, infra_config: InfraConfig): + self.pool_pre_ping = infra_config.db_engine_disconnect_strategy == "pessimistic" + self.pool_size = infra_config.db_engine_pool_size + self.max_overflow = infra_config.db_engine_max_overflow + self.echo = infra_config.db_engine_echo + self.echo_pool = infra_config.db_engine_echo_pool + self.sessions = self.refresh_sessions() + + def refresh_sessions(self) -> DBSessions: + db_connection = get_engine_url(read_only=False, sync=True) + # use sync engine as proxy for credential expiration + self.credential_expiration_timestamp = db_connection.expiry_in_sec + pg_engine = create_engine( + db_connection.url, + echo=self.echo, + echo_pool=self.echo_pool, + pool_pre_ping=self.pool_pre_ping, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + future=True, + logging_name="sync", + ) + session_sync = SyncDBSession( + engine=pg_engine, + session=sessionmaker(autocommit=False, autoflush=False, bind=pg_engine), + ) + + pg_engine_ro = create_engine( + url=get_engine_url(read_only=True, sync=True).url, + echo=self.echo, + echo_pool=self.echo_pool, + pool_pre_ping=self.pool_pre_ping, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + future=True, + logging_name="sync_ro", + ) + session_sync_ro = SyncDBSession( + engine=pg_engine_ro, + session=sessionmaker(autocommit=False, autoflush=False, bind=pg_engine_ro), + ) + + pg_engine_async = create_async_engine( + url=get_engine_url(read_only=False, sync=False).url, + echo=self.echo, + echo_pool=self.echo_pool, + pool_pre_ping=self.pool_pre_ping, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + future=True, + logging_name="async", + ) + session_async = AsyncDBSession( + engine=pg_engine_async, + session=async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_async, + expire_on_commit=False, + ), + ) + + pg_engine_async_ro = create_async_engine( + url=get_engine_url(read_only=True, sync=False).url, + echo=self.echo, + echo_pool=self.echo_pool, + pool_pre_ping=self.pool_pre_ping, + pool_size=self.pool_size, + max_overflow=self.max_overflow, + future=True, + logging_name="async_ro", + ) + session_async_ro = AsyncDBSession( + engine=pg_engine_async_ro, + session=async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_async_ro, + expire_on_commit=False, + ), + ) + + pg_engine_async_null_pool = create_async_engine( + url=get_engine_url(read_only=False, sync=False).url, + echo=self.echo, + echo_pool=self.echo_pool, + future=True, + poolclass=NullPool, + logging_name="async_null", + ) + + session_async_null_pool = AsyncDBSession( + engine=pg_engine_async_null_pool, + session=async_sessionmaker( + autocommit=False, + autoflush=False, + bind=pg_engine_async_null_pool, + expire_on_commit=False, + ), + ) + + return DBSessions( + session_sync=session_sync, + session_sync_ro=session_sync_ro, + session_async=session_async, + session_async_ro=session_async_ro, + session_async_null_pool=session_async_null_pool, + ) + + def _is_credentials_expired(self): + return ( + self.credential_expiration_timestamp is not None + and time.time() + > self.credential_expiration_timestamp - self.credential_expiration_buffer_sec + ) + + def _maybe_refresh_sessions(self): + if self._is_credentials_expired(): + old_sessions = self.sessions + self.sessions = self.refresh_sessions() + old_sessions.session_sync.engine.dispose() + old_sessions.session_sync_ro.engine.dispose() + old_sessions.session_async.engine.dispose() + old_sessions.session_async_ro.engine.dispose() + old_sessions.session_async_null_pool.engine.dispose() + + def get_session_sync(self) -> sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_sync.session + + def get_session_sync_ro(self) -> sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_sync_ro.session + + def get_session_async(self) -> async_sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_async.session + + def get_session_async_ro(self) -> async_sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_async_ro.session + + def get_session_async_null_pool(self) -> async_sessionmaker: + self._maybe_refresh_sessions() + return self.sessions.session_async_null_pool.session + + +db_manager: Optional[DBManager] = None + + +def get_db_manager(): + global db_manager + if db_manager is None: + db_manager = DBManager(infra_config()) + return db_manager + + +def get_session(): + return get_db_manager().get_session_sync() + + +def get_session_read_only(): + return get_db_manager().get_session_sync_ro() + + +def get_session_async(): + return get_db_manager().get_session_async() + + +def get_session_async_null_pool(): + return get_db_manager().get_session_async_null_pool() + + +def get_session_read_only_async(): + return get_db_manager().get_session_async_ro() + + +Base = declarative_base() + + +def get_session_iterator() -> Iterator[sqlalchemy.orm.Session]: + """Utility to return an iterator with an instantiated session in the ML Infra database.""" + Session = get_session() + session = Session() + try: + yield session + finally: + session.close() + + +def get_read_only_session_iterator() -> Iterator[sqlalchemy.orm.Session]: + """Utility to return an iterator with an instantiated session in the ML Infra database.""" + SessionReadOnly = get_session_read_only() + session = SessionReadOnly() + try: + yield session + finally: + session.close() diff --git a/server/llm_engine_server/db/endpoint_row_lock.py b/model-engine/model_engine_server/db/endpoint_row_lock.py similarity index 94% rename from server/llm_engine_server/db/endpoint_row_lock.py rename to model-engine/model_engine_server/db/endpoint_row_lock.py index ed5786ac..b3d0e307 100644 --- a/server/llm_engine_server/db/endpoint_row_lock.py +++ b/model-engine/model_engine_server/db/endpoint_row_lock.py @@ -4,12 +4,12 @@ import time from contextlib import AbstractContextManager -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.core.loggers import logger_name, make_logger from sqlalchemy import BIGINT, cast, func, select from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm.session import Session -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) BLOCKING_LOCK_TIMEOUT_SECONDS = 120 BLOCKING_LOCK_TIMEOUT_POLL_FREQ_SECONDS = 0.5 @@ -17,14 +17,10 @@ def get_lock_key(user_id: str, endpoint_name: str) -> int: uid_hash = int.from_bytes( - hashlib.sha256(bytes(user_id, "utf-8")).digest()[:4], - byteorder="little", - signed=False, + hashlib.sha256(bytes(user_id, "utf-8")).digest()[:4], byteorder="little", signed=False ) endpoint_name_hash = int.from_bytes( - hashlib.sha256(bytes(endpoint_name, "utf-8")).digest()[:4], - byteorder="little", - signed=False, + hashlib.sha256(bytes(endpoint_name, "utf-8")).digest()[:4], byteorder="little", signed=False ) return 2**32 * uid_hash + endpoint_name_hash - 2**63 diff --git a/server/llm_engine_server/db/local_setup.py b/model-engine/model_engine_server/db/local_setup.py similarity index 92% rename from server/llm_engine_server/db/local_setup.py rename to model-engine/model_engine_server/db/local_setup.py index cb446d5e..4db34463 100644 --- a/server/llm_engine_server/db/local_setup.py +++ b/model-engine/model_engine_server/db/local_setup.py @@ -2,13 +2,13 @@ import os import psycopg2 -from llm_engine_server.db.base import Base -from llm_engine_server.db.models import * +from model_engine_server.db.base import Base +from model_engine_server.db.models import * from sqlalchemy import create_engine from sqlalchemy.engine import Engine from tenacity import Retrying, stop_after_attempt, wait_exponential -SCHEMAS = ["llm_engine", "model"] +SCHEMAS = ["hosted_model_inference", "model"] def init_database(database_url: str, psycopg_connection): diff --git a/model-engine/model_engine_server/db/migrations/README b/model-engine/model_engine_server/db/migrations/README new file mode 100644 index 00000000..34a0d901 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/README @@ -0,0 +1,43 @@ +# Setup + +We introduce alembic by +1. dumping the current db schemas into 'initial.sql' via pg_dump + +``` +pg_dump -h $HOST -U postgres -O -s -d $DB_NAME -n hosted_model_inference -n model -f initial.sql +``` + +2. writing an initial revision that reads and applies intial.sql script + +``` +alembic revision -m “initial” +``` + +3. Stamping the current revision to our production db to avoid actually running it on production + +``` +alembic stamp fa3267c80731 +``` + + +# Test db migration from scratch + +## Set up postgresql + +``` +docker pull postgres +docker run --name postgres -e POSTGRES_PASSWORD=password -d -p 5432:5432 postgres +``` + +## Run migration script + +``` +PYTHONPATH="${PYTHONPATH}:" +ML_INFRA_DATABASE_URL="postgresql://postgres:password@localhost:54320/postgres" bash run_database_migration.sh +``` + + +To reset db, you can recreate docker or run +``` +psql "$ML_INFRA_DATABASE_URL" -c "DROP table if exists public.alembic_version_model_engine; DROP schema if exists hosted_model_inference CASCADE; DROP schema if exists model CASCADE" +``` diff --git a/model-engine/model_engine_server/db/migrations/alembic.ini b/model-engine/model_engine_server/db/migrations/alembic.ini new file mode 100644 index 00000000..23f7c0ea --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic.ini @@ -0,0 +1,105 @@ +# A generic, single database configuration. + +[alembic] +# path to migration scripts +script_location = alembic + +# template used to generate migration file names; The default value is %%(rev)s_%%(slug)s +# Uncomment the line below if you want the files to be prepended with date and time +# see https://alembic.sqlalchemy.org/en/latest/tutorial.html#editing-the-ini-file +# for all available tokens +file_template = %%(year)d_%%(month).2d_%%(day).2d_%%(hour).2d%%(minute).2d-%%(rev)s_%%(slug)s + +# sys.path path, will be prepended to sys.path if present. +# defaults to the current working directory. +prepend_sys_path = . + +# timezone to use when rendering the date within the migration file +# as well as the filename. +# If specified, requires the python-dateutil library that can be +# installed by adding `alembic[tz]` to the pip requirements +# string value is passed to dateutil.tz.gettz() +# leave blank for localtime +# timezone = + +# max length of characters to apply to the +# "slug" field +# truncate_slug_length = 40 + +# set to 'true' to run the environment during +# the 'revision' command, regardless of autogenerate +# revision_environment = false + +# set to 'true' to allow .pyc and .pyo files without +# a source .py file to be detected as revisions in the +# versions/ directory +sourceless = false + +# version location specification; This defaults +# to alembic/versions. When using multiple version +# directories, initial revisions must be specified with --version-path. +# The path separator used here should be the separator specified by "version_path_separator" below. +# version_locations = %(here)s/bar:%(here)s/bat:alembic/versions + +# version path separator; As mentioned above, this is the character used to split +# version_locations. The default within new alembic.ini files is "os", which uses os.pathsep. +# If this key is omitted entirely, it falls back to the legacy behavior of splitting on spaces and/or commas. +# Valid values for version_path_separator are: +# +# version_path_separator = : +# version_path_separator = ; +# version_path_separator = space +version_path_separator = os # Use os.pathsep. Default configuration used for new projects. + +# the output encoding used when revision files +# are written from script.py.mako +# output_encoding = utf-8 + +sqlalchemy.url = driver://user:pass@localhost/dbname + + +[post_write_hooks] +# post_write_hooks defines scripts or Python functions that are run +# on newly generated revision scripts. See the documentation for further +# detail and examples + +# format using "black" - use the console_scripts runner, against the "black" entrypoint +# hooks = black +# black.type = console_scripts +# black.entrypoint = black +# black.options = -l 79 REVISION_SCRIPT_FILENAME + +# Logging configuration +[loggers] +keys = root,sqlalchemy,alembic + +[handlers] +keys = console + +[formatters] +keys = generic + +[logger_root] +level = WARN +handlers = console +qualname = + +[logger_sqlalchemy] +level = WARN +handlers = +qualname = sqlalchemy.engine + +[logger_alembic] +level = INFO +handlers = +qualname = alembic + +[handler_console] +class = StreamHandler +args = (sys.stderr,) +level = NOTSET +formatter = generic + +[formatter_generic] +format = %(levelname)-5.5s [%(name)s] %(message)s +datefmt = %H:%M:%S diff --git a/server/llm_engine_server/db/migrations/alembic/env.py b/model-engine/model_engine_server/db/migrations/alembic/env.py similarity index 72% rename from server/llm_engine_server/db/migrations/alembic/env.py rename to model-engine/model_engine_server/db/migrations/alembic/env.py index c85751f9..6cb95d67 100644 --- a/server/llm_engine_server/db/migrations/alembic/env.py +++ b/model-engine/model_engine_server/db/migrations/alembic/env.py @@ -3,19 +3,27 @@ from logging.config import fileConfig from alembic import context -from llm_engine_server.db.base import get_engine_url +from model_engine_server.db.base import get_engine_url from sqlalchemy import engine_from_config, pool env = os.environ.get("ENV") -assert env is not None, "Expected ENV to be a nonempty environment variable." +# if env is None: +# assert ( +# os.getenv("ML_INFRA_DATABASE_URL") is not None +# ), "Expected ML_INFRA_DATABASE_URL to be set if ENV is not set." +# this is the Alembic Config object, which provides +# access to the values within the .ini file in use. config = context.config -config.set_main_option("sqlalchemy.url", get_engine_url(env, read_only=False)) +config.set_main_option("sqlalchemy.url", get_engine_url(env, read_only=False).url) + +ALEMBIC_TABLE_NAME = "alembic_version_model_engine" # Interpret the config file for Python logging. # This line sets up loggers basically. -fileConfig(config.config_file_name) +if config.config_file_name is not None: + fileConfig(config.config_file_name) # add your model's MetaData object here # for 'autogenerate' support @@ -23,14 +31,13 @@ # target_metadata = mymodel.Base.metadata target_metadata = None - # other values from the config, defined by the needs of env.py, # can be acquired: # my_important_option = config.get_main_option("my_important_option") # ... etc. -def run_migrations_offline(): +def run_migrations_offline() -> None: """Run migrations in 'offline' mode. This configures the context with just a URL @@ -48,7 +55,8 @@ def run_migrations_offline(): target_metadata=target_metadata, literal_binds=True, dialect_opts={"paramstyle": "named"}, - include_schemas=True, + version_table=ALEMBIC_TABLE_NAME, + version_table_schema="public", ) try: @@ -59,7 +67,7 @@ def run_migrations_offline(): raise e -def run_migrations_online(): +def run_migrations_online() -> None: """Run migrations in 'online' mode. In this scenario we need to create an Engine @@ -73,7 +81,12 @@ def run_migrations_online(): ) with connectable.connect() as connection: - context.configure(connection=connection, target_metadata=target_metadata) + context.configure( + connection=connection, + target_metadata=target_metadata, + version_table=ALEMBIC_TABLE_NAME, + version_table_schema="public", + ) try: with context.begin_transaction(): diff --git a/server/llm_engine_server/db/migrations/alembic/script.py.mako b/model-engine/model_engine_server/db/migrations/alembic/script.py.mako similarity index 90% rename from server/llm_engine_server/db/migrations/alembic/script.py.mako rename to model-engine/model_engine_server/db/migrations/alembic/script.py.mako index 2c015630..55df2863 100644 --- a/server/llm_engine_server/db/migrations/alembic/script.py.mako +++ b/model-engine/model_engine_server/db/migrations/alembic/script.py.mako @@ -16,9 +16,9 @@ branch_labels = ${repr(branch_labels)} depends_on = ${repr(depends_on)} -def upgrade(): +def upgrade() -> None: ${upgrades if upgrades else "pass"} -def downgrade(): +def downgrade() -> None: ${downgrades if downgrades else "pass"} diff --git a/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py new file mode 100644 index 00000000..efee8963 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1736-fa3267c80731_initial.py @@ -0,0 +1,30 @@ +"""“initial” + +Revision ID: fa3267c80731 +Revises: +Create Date: 2024-09-09 17:36:30.097136 + +""" + +from pathlib import Path + +INITIAL_MIGRATION_PATH = Path(__file__).parent / "../../initial.sql" + + +import sqlalchemy as sa +from alembic import op + +# revision identifiers, used by Alembic. +revision = "fa3267c80731" +down_revision = None +branch_labels = None +depends_on = None + + +def upgrade() -> None: + with open(INITIAL_MIGRATION_PATH) as fd: + op.execute(fd.read()) + + +def downgrade() -> None: + pass diff --git a/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1831-b574e9711e35_chat_completion_add_extra_routes.py b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1831-b574e9711e35_chat_completion_add_extra_routes.py new file mode 100644 index 00000000..43279e0f --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_09_1831-b574e9711e35_chat_completion_add_extra_routes.py @@ -0,0 +1,32 @@ +"""chat completion - Add extra_routes + +Revision ID: b574e9711e35 +Revises: fa3267c80731 +Create Date: 2024-09-09 18:31:59.422082 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import ARRAY + +# revision identifiers, used by Alembic. +revision = "b574e9711e35" +down_revision = "fa3267c80731" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "bundles", + sa.Column("runnable_image_extra_routes", ARRAY(sa.Text), nullable=True), + schema="hosted_model_inference", + ) + + +def downgrade(): + op.drop_column( + "bundles", + "runnable_image_extra_routes", + schema="hosted_model_inference", + ) diff --git a/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_24_1456-f55525c81eb5_multinode_bundle.py b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_24_1456-f55525c81eb5_multinode_bundle.py new file mode 100644 index 00000000..532b0e38 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/alembic/versions/2024_09_24_1456-f55525c81eb5_multinode_bundle.py @@ -0,0 +1,42 @@ +"""multinode_bundle + +Revision ID: f55525c81eb5 +Revises: b574e9711e35 +Create Date: 2024-09-24 14:56:36.287001 + +""" +import sqlalchemy as sa +from alembic import op +from sqlalchemy.dialects.postgresql import ARRAY + +# revision identifiers, used by Alembic. +revision = "f55525c81eb5" +down_revision = "b574e9711e35" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "bundles", + sa.Column("runnable_image_worker_command", ARRAY(sa.Text), nullable=True), + schema="hosted_model_inference", + ) + op.add_column( + "bundles", + sa.Column("runnable_image_worker_env", sa.JSON, nullable=True), + schema="hosted_model_inference", + ) + + +def downgrade() -> None: + op.drop_column( + "bundles", + "runnable_image_worker_command", + schema="hosted_model_inference", + ) + op.drop_column( + "bundles", + "runnable_image_worker_env", + schema="hosted_model_inference", + ) diff --git a/model-engine/model_engine_server/db/migrations/initial.sql b/model-engine/model_engine_server/db/migrations/initial.sql new file mode 100644 index 00000000..65728431 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/initial.sql @@ -0,0 +1,617 @@ +-- +-- PostgreSQL database dump +-- + +-- Dumped from database version 13.12 +-- Dumped by pg_dump version 13.16 (Ubuntu 13.16-1.pgdg20.04+1) + +SET statement_timeout = 0; +SET lock_timeout = 0; +SET idle_in_transaction_session_timeout = 0; +SET client_encoding = 'UTF8'; +SET standard_conforming_strings = on; +SELECT pg_catalog.set_config('search_path', '', false); +SET check_function_bodies = false; +SET xmloption = content; +SET client_min_messages = warning; +SET row_security = off; + +-- +-- Name: hosted_model_inference; Type: SCHEMA; Schema: -; Owner: - +-- + +CREATE SCHEMA hosted_model_inference; + + +-- +-- Name: model; Type: SCHEMA; Schema: -; Owner: - +-- + +CREATE SCHEMA model; + + +SET default_tablespace = ''; + +SET default_table_access_method = heap; + +-- +-- Name: batch_jobs; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.batch_jobs ( + id text NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + completed_at timestamp with time zone, + status text NOT NULL, + created_by character varying(24) NOT NULL, + owner character varying(24) NOT NULL, + model_bundle_id text NOT NULL, + model_endpoint_id text, + task_ids_location text, + result_location text +); + + +-- +-- Name: bundles; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.bundles ( + id text NOT NULL, + name character varying(50), + created_by character varying(24), + created_at timestamp with time zone DEFAULT now(), + location text, + version character varying(24), + registered_model_name text, + bundle_metadata json, + requirements json, + env_params json, + packaging_type text, + app_config json, + model_artifact_ids text[] DEFAULT '{}'::text[], + schema_location text, + owner character varying(24) NOT NULL, + flavor text, + artifact_requirements text[], + artifact_app_config json, + artifact_framework_type text, + artifact_pytorch_image_tag text, + artifact_tensorflow_version text, + artifact_image_repository text, + artifact_image_tag text, + cloudpickle_artifact_load_predict_fn text, + cloudpickle_artifact_load_model_fn text, + zip_artifact_load_predict_fn_module_path text, + zip_artifact_load_model_fn_module_path text, + runnable_image_repository text, + runnable_image_tag text, + runnable_image_command text[], + runnable_image_env json, + runnable_image_protocol text, + artifact_location text, + runnable_image_readiness_initial_delay_seconds integer, + triton_enhanced_runnable_image_model_repository text, + triton_enhanced_runnable_image_model_replicas json, + triton_enhanced_runnable_image_num_cpu numeric, + triton_enhanced_runnable_image_commit_tag text, + triton_enhanced_runnable_image_storage text, + triton_enhanced_runnable_image_memory text, + triton_enhanced_runnable_image_readiness_initial_delay_seconds integer, + streaming_enhanced_runnable_image_streaming_command text[], + runnable_image_predict_route text, + streaming_enhanced_runnable_image_streaming_predict_route text, + runnable_image_healthcheck_route text, + CONSTRAINT bundles_flavor_0 CHECK ((flavor = ANY (ARRAY['cloudpickle_artifact'::text, 'zip_artifact'::text, 'runnable_image'::text, 'triton_enhanced_runnable_image'::text, 'streaming_enhanced_runnable_image'::text]))), + CONSTRAINT bundles_flavor_1 CHECK (((flavor ~~ '%_artifact'::text) = (artifact_requirements IS NOT NULL))), + CONSTRAINT bundles_flavor_10 CHECK (((flavor = 'zip_artifact'::text) = (zip_artifact_load_predict_fn_module_path IS NOT NULL))), + CONSTRAINT bundles_flavor_11 CHECK (((flavor = 'zip_artifact'::text) = (zip_artifact_load_model_fn_module_path IS NOT NULL))), + CONSTRAINT bundles_flavor_12 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_repository IS NOT NULL))), + CONSTRAINT bundles_flavor_13 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_tag IS NOT NULL))), + CONSTRAINT bundles_flavor_14 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_command IS NOT NULL))), + CONSTRAINT bundles_flavor_15 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_protocol IS NOT NULL))), + CONSTRAINT bundles_flavor_16 CHECK (((flavor = 'triton_enhanced_runnable_image'::text) = (triton_enhanced_runnable_image_model_repository IS NOT NULL))), + CONSTRAINT bundles_flavor_17 CHECK (((flavor = 'triton_enhanced_runnable_image'::text) = (triton_enhanced_runnable_image_num_cpu IS NOT NULL))), + CONSTRAINT bundles_flavor_18 CHECK (((flavor = 'triton_enhanced_runnable_image'::text) = (triton_enhanced_runnable_image_commit_tag IS NOT NULL))), + CONSTRAINT bundles_flavor_19 CHECK (((flavor = 'triton_enhanced_runnable_image'::text) = (triton_enhanced_runnable_image_readiness_initial_delay_seconds IS NOT NULL))), + CONSTRAINT bundles_flavor_2 CHECK (((flavor ~~ '%_artifact'::text) = (artifact_location IS NOT NULL))), + CONSTRAINT bundles_flavor_20 CHECK (((flavor = 'streaming_enhanced_runnable_image'::text) = (streaming_enhanced_runnable_image_streaming_command IS NOT NULL))), + CONSTRAINT bundles_flavor_21 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_predict_route IS NOT NULL))), + CONSTRAINT bundles_flavor_22 CHECK (((flavor ~~ '%runnable_image'::text) = (runnable_image_healthcheck_route IS NOT NULL))), + CONSTRAINT bundles_flavor_23 CHECK (((flavor = 'streaming_enhanced_runnable_image'::text) = (streaming_enhanced_runnable_image_streaming_predict_route IS NOT NULL))), + CONSTRAINT bundles_flavor_3 CHECK (((flavor ~~ '%_artifact'::text) = (artifact_framework_type IS NOT NULL))), + CONSTRAINT bundles_flavor_4 CHECK (((artifact_framework_type = 'pytorch'::text) = (artifact_pytorch_image_tag IS NOT NULL))), + CONSTRAINT bundles_flavor_5 CHECK (((artifact_framework_type = 'tensorflow'::text) = (artifact_tensorflow_version IS NOT NULL))), + CONSTRAINT bundles_flavor_6 CHECK (((artifact_framework_type = 'custom_base_image'::text) = (artifact_image_repository IS NOT NULL))), + CONSTRAINT bundles_flavor_7 CHECK (((artifact_framework_type = 'custom_base_image'::text) = (artifact_image_tag IS NOT NULL))), + CONSTRAINT bundles_flavor_8 CHECK (((flavor = 'cloudpickle_artifact'::text) = (cloudpickle_artifact_load_predict_fn IS NOT NULL))), + CONSTRAINT bundles_flavor_9 CHECK (((flavor = 'cloudpickle_artifact'::text) = (cloudpickle_artifact_load_model_fn IS NOT NULL))) +); + + +-- +-- Name: docker_image_batch_job_bundles; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.docker_image_batch_job_bundles ( + id text NOT NULL, + name text NOT NULL, + created_by character varying(24) NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + owner character varying(24) NOT NULL, + image_repository text NOT NULL, + image_tag text NOT NULL, + command text[] NOT NULL, + env json NOT NULL, + mount_location text, + cpus text, + memory text, + storage text, + gpus integer, + gpu_type text, + public boolean +); + + +-- +-- Name: endpoints; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.endpoints ( + id text NOT NULL, + name text, + created_by character varying(24), + created_at timestamp with time zone DEFAULT now(), + last_updated_at timestamp with time zone DEFAULT now(), + current_bundle_id text, + endpoint_metadata jsonb, + creation_task_id text, + endpoint_type text, + destination text, + endpoint_status text, + owner character varying(24) NOT NULL, + public_inference boolean +); + + +-- +-- Name: triggers; Type: TABLE; Schema: hosted_model_inference; Owner: - +-- + +CREATE TABLE hosted_model_inference.triggers ( + id character varying NOT NULL, + name character varying NOT NULL, + owner character varying NOT NULL, + created_by character varying NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + cron_schedule character varying NOT NULL, + docker_image_batch_job_bundle_id character varying NOT NULL, + default_job_config jsonb, + default_job_metadata jsonb +); + + +-- +-- Name: model_artifacts; Type: TABLE; Schema: model; Owner: - +-- + +CREATE TABLE model.model_artifacts ( + id text NOT NULL, + name text NOT NULL, + description text, + is_public boolean NOT NULL, + created_by character varying(24) NOT NULL, + owner character varying(24) NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + input_schema json, + output_schema json, + config json NOT NULL, + location text NOT NULL, + format text NOT NULL, + format_metadata json NOT NULL, + source text NOT NULL, + source_metadata json NOT NULL +); + + +-- +-- Name: model_versions; Type: TABLE; Schema: model; Owner: - +-- + +CREATE TABLE model.model_versions ( + id text NOT NULL, + model_id text NOT NULL, + version_number integer NOT NULL, + tags text[] NOT NULL, + created_by character varying(24) NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL, + launch_model_bundle_id text, + nucleus_model_id text, + metadata json DEFAULT '{}'::json NOT NULL +); + + +-- +-- Name: models; Type: TABLE; Schema: model; Owner: - +-- + +CREATE TABLE model.models ( + id text NOT NULL, + name text NOT NULL, + description text, + task_types text[] NOT NULL, + created_by character varying(24) NOT NULL, + owner character varying(24) NOT NULL, + created_at timestamp with time zone DEFAULT now() NOT NULL +); + + +-- +-- Name: batch_jobs batch_jobs_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.batch_jobs + ADD CONSTRAINT batch_jobs_pkey PRIMARY KEY (id); + + +-- +-- Name: bundles bundles_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.bundles + ADD CONSTRAINT bundles_pkey PRIMARY KEY (id); + + +-- +-- Name: docker_image_batch_job_bundles docker_image_batch_job_bundles_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.docker_image_batch_job_bundles + ADD CONSTRAINT docker_image_batch_job_bundles_pkey PRIMARY KEY (id); + + +-- +-- Name: endpoints endpoint_name_created_by_uc; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.endpoints + ADD CONSTRAINT endpoint_name_created_by_uc UNIQUE (name, created_by); + + +-- +-- Name: endpoints endpoint_name_owner_uc; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.endpoints + ADD CONSTRAINT endpoint_name_owner_uc UNIQUE (name, owner); + + +-- +-- Name: endpoints endpoints_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.endpoints + ADD CONSTRAINT endpoints_pkey PRIMARY KEY (id); + + +-- +-- Name: triggers triggers_pkey; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.triggers + ADD CONSTRAINT triggers_pkey PRIMARY KEY (id); + + +-- +-- Name: triggers uq_triggers_name_owner; Type: CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.triggers + ADD CONSTRAINT uq_triggers_name_owner UNIQUE (name, owner); + + +-- +-- Name: model_versions launch_model_bundle_id_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT launch_model_bundle_id_uc UNIQUE (launch_model_bundle_id); + + +-- +-- Name: model_artifacts model_artifacts_owner_name_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_artifacts + ADD CONSTRAINT model_artifacts_owner_name_uc UNIQUE (owner, name); + + +-- +-- Name: model_artifacts model_artifacts_pkey; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_artifacts + ADD CONSTRAINT model_artifacts_pkey PRIMARY KEY (id); + + +-- +-- Name: model_versions model_id_version_number_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT model_id_version_number_uc UNIQUE (model_id, version_number); + + +-- +-- Name: model_versions model_versions_pkey; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT model_versions_pkey PRIMARY KEY (id); + + +-- +-- Name: models models_owner_name_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.models + ADD CONSTRAINT models_owner_name_uc UNIQUE (owner, name); + + +-- +-- Name: models models_pkey; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.models + ADD CONSTRAINT models_pkey PRIMARY KEY (id); + + +-- +-- Name: model_versions nucleus_model_id_uc; Type: CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT nucleus_model_id_uc UNIQUE (nucleus_model_id); + + +-- +-- Name: endpoint_name_llm_uc; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE UNIQUE INDEX endpoint_name_llm_uc ON hosted_model_inference.endpoints USING btree (name) WHERE (endpoint_metadata ? '_llm'::text); + + +-- +-- Name: idx_endpoint_metadata; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX idx_endpoint_metadata ON hosted_model_inference.endpoints USING gin (endpoint_metadata); + + +-- +-- Name: idx_trigger_name; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX idx_trigger_name ON hosted_model_inference.triggers USING btree (name); + + +-- +-- Name: ix_hosted_model_inference_batch_jobs_created_by; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_batch_jobs_created_by ON hosted_model_inference.batch_jobs USING btree (created_by); + + +-- +-- Name: ix_hosted_model_inference_batch_jobs_owner; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_batch_jobs_owner ON hosted_model_inference.batch_jobs USING btree (owner); + + +-- +-- Name: ix_hosted_model_inference_bundles_created_by; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_bundles_created_by ON hosted_model_inference.bundles USING btree (created_by); + + +-- +-- Name: ix_hosted_model_inference_bundles_name; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_bundles_name ON hosted_model_inference.bundles USING btree (name); + + +-- +-- Name: ix_hosted_model_inference_docker_image_batch_job_bundle_79a0; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_docker_image_batch_job_bundle_79a0 ON hosted_model_inference.docker_image_batch_job_bundles USING btree (created_by); + + +-- +-- Name: ix_hosted_model_inference_docker_image_batch_job_bundles_owner; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_docker_image_batch_job_bundles_owner ON hosted_model_inference.docker_image_batch_job_bundles USING btree (owner); + + +-- +-- Name: ix_hosted_model_inference_endpoints_created_by; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_endpoints_created_by ON hosted_model_inference.endpoints USING btree (created_by); + + +-- +-- Name: ix_hosted_model_inference_endpoints_name; Type: INDEX; Schema: hosted_model_inference; Owner: - +-- + +CREATE INDEX ix_hosted_model_inference_endpoints_name ON hosted_model_inference.endpoints USING btree (name); + + +-- +-- Name: ix_model_model_artifacts_created_by; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_created_by ON model.model_artifacts USING btree (created_by); + + +-- +-- Name: ix_model_model_artifacts_description; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_description ON model.model_artifacts USING btree (description); + + +-- +-- Name: ix_model_model_artifacts_format; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_format ON model.model_artifacts USING btree (format); + + +-- +-- Name: ix_model_model_artifacts_is_public; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_is_public ON model.model_artifacts USING btree (is_public); + + +-- +-- Name: ix_model_model_artifacts_name; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_name ON model.model_artifacts USING btree (name); + + +-- +-- Name: ix_model_model_artifacts_owner; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_owner ON model.model_artifacts USING btree (owner); + + +-- +-- Name: ix_model_model_artifacts_source; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_artifacts_source ON model.model_artifacts USING btree (source); + + +-- +-- Name: ix_model_model_versions_created_by; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_versions_created_by ON model.model_versions USING btree (created_by); + + +-- +-- Name: ix_model_model_versions_model_id; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_versions_model_id ON model.model_versions USING btree (model_id); + + +-- +-- Name: ix_model_model_versions_tags; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_versions_tags ON model.model_versions USING btree (tags); + + +-- +-- Name: ix_model_model_versions_version_number; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_model_versions_version_number ON model.model_versions USING btree (version_number); + + +-- +-- Name: ix_model_models_created_by; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_created_by ON model.models USING btree (created_by); + + +-- +-- Name: ix_model_models_description; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_description ON model.models USING btree (description); + + +-- +-- Name: ix_model_models_name; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_name ON model.models USING btree (name); + + +-- +-- Name: ix_model_models_owner; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_owner ON model.models USING btree (owner); + + +-- +-- Name: ix_model_models_task_types; Type: INDEX; Schema: model; Owner: - +-- + +CREATE INDEX ix_model_models_task_types ON model.models USING btree (task_types); + + +-- +-- Name: batch_jobs batch_jobs_model_bundle_id_fkey; Type: FK CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.batch_jobs + ADD CONSTRAINT batch_jobs_model_bundle_id_fkey FOREIGN KEY (model_bundle_id) REFERENCES hosted_model_inference.bundles(id); + + +-- +-- Name: batch_jobs batch_jobs_model_endpoint_id_fkey; Type: FK CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.batch_jobs + ADD CONSTRAINT batch_jobs_model_endpoint_id_fkey FOREIGN KEY (model_endpoint_id) REFERENCES hosted_model_inference.endpoints(id) ON DELETE SET NULL; + + +-- +-- Name: endpoints endpoints_current_bundle_id_fkey; Type: FK CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.endpoints + ADD CONSTRAINT endpoints_current_bundle_id_fkey FOREIGN KEY (current_bundle_id) REFERENCES hosted_model_inference.bundles(id); + + +-- +-- Name: triggers triggers_docker_image_batch_job_bundle_id_fkey; Type: FK CONSTRAINT; Schema: hosted_model_inference; Owner: - +-- + +ALTER TABLE ONLY hosted_model_inference.triggers + ADD CONSTRAINT triggers_docker_image_batch_job_bundle_id_fkey FOREIGN KEY (docker_image_batch_job_bundle_id) REFERENCES hosted_model_inference.docker_image_batch_job_bundles(id); + + +-- +-- Name: model_versions model_versions_launch_model_bundle_id_fkey; Type: FK CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT model_versions_launch_model_bundle_id_fkey FOREIGN KEY (launch_model_bundle_id) REFERENCES hosted_model_inference.bundles(id); + + +-- +-- Name: model_versions model_versions_model_id_fkey; Type: FK CONSTRAINT; Schema: model; Owner: - +-- + +ALTER TABLE ONLY model.model_versions + ADD CONSTRAINT model_versions_model_id_fkey FOREIGN KEY (model_id) REFERENCES model.models(id); + + +-- +-- PostgreSQL database dump complete +-- + diff --git a/model-engine/model_engine_server/db/migrations/run_database_migration.sh b/model-engine/model_engine_server/db/migrations/run_database_migration.sh new file mode 100755 index 00000000..8b25f20e --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/run_database_migration.sh @@ -0,0 +1,10 @@ +#!/bin/bash + +# Get the directory of this script +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +# Change directory to the directory of this script +cd $DIR + +# Runs database migration +alembic upgrade head \ No newline at end of file diff --git a/model-engine/model_engine_server/db/migrations/stamp_initial_schema.sh b/model-engine/model_engine_server/db/migrations/stamp_initial_schema.sh new file mode 100755 index 00000000..bf7d3781 --- /dev/null +++ b/model-engine/model_engine_server/db/migrations/stamp_initial_schema.sh @@ -0,0 +1,13 @@ +#!/bin/bash + +# Usage +# ML_INFRA_DATABASE_URL="postgresql://postgres:password@localhost:54320/postgres" bash stamp_initial_schema.sh + +# Get the directory of this script +DIR="$( cd "$( dirname "${BASH_SOURCE[0]}" )" >/dev/null 2>&1 && pwd )" + +# Change directory to the directory of this script +cd $DIR + +# Stamps initial revision to new table +alembic stamp fa3267c80731 \ No newline at end of file diff --git a/server/llm_engine_server/db/models/__init__.py b/model-engine/model_engine_server/db/models/__init__.py similarity index 68% rename from server/llm_engine_server/db/models/__init__.py rename to model-engine/model_engine_server/db/models/__init__.py index bd6a9788..e7a62852 100644 --- a/server/llm_engine_server/db/models/__init__.py +++ b/model-engine/model_engine_server/db/models/__init__.py @@ -1,6 +1,6 @@ from typing import Sequence -from .llm_engine import BatchJob, Bundle, DockerImageBatchJobBundle, Endpoint +from .hosted_model_inference import BatchJob, Bundle, DockerImageBatchJobBundle, Endpoint, Trigger from .model import Model, ModelArtifact, ModelVersion __all__: Sequence[str] = [ @@ -11,4 +11,5 @@ "Model", "ModelArtifact", "ModelVersion", + "Trigger", ] diff --git a/server/llm_engine_server/domain/authorization/__init__.py b/model-engine/model_engine_server/db/models/common/__init__.py similarity index 100% rename from server/llm_engine_server/domain/authorization/__init__.py rename to model-engine/model_engine_server/db/models/common/__init__.py diff --git a/server/llm_engine_server/db/models/common/query.py b/model-engine/model_engine_server/db/models/common/query.py similarity index 100% rename from server/llm_engine_server/db/models/common/query.py rename to model-engine/model_engine_server/db/models/common/query.py diff --git a/server/llm_engine_server/db/models/common/record.py b/model-engine/model_engine_server/db/models/common/record.py similarity index 92% rename from server/llm_engine_server/db/models/common/record.py rename to model-engine/model_engine_server/db/models/common/record.py index ae9602df..d2ecd2ce 100644 --- a/server/llm_engine_server/db/models/common/record.py +++ b/model-engine/model_engine_server/db/models/common/record.py @@ -2,9 +2,9 @@ from typing import Generic, Optional, Sequence, TypeVar -from llm_engine_server.db.base import Base -from llm_engine_server.db.models.common.query import Query -from llm_engine_server.db.models.exceptions import EntityNotFoundError +from model_engine_server.db.base import Base +from model_engine_server.db.models.common.query import Query +from model_engine_server.db.models.exceptions import EntityNotFoundError from sqlalchemy import select from sqlalchemy.orm import Session diff --git a/server/llm_engine_server/db/models/constants.py b/model-engine/model_engine_server/db/models/constants.py similarity index 100% rename from server/llm_engine_server/db/models/constants.py rename to model-engine/model_engine_server/db/models/constants.py diff --git a/server/llm_engine_server/db/models/exceptions.py b/model-engine/model_engine_server/db/models/exceptions.py similarity index 100% rename from server/llm_engine_server/db/models/exceptions.py rename to model-engine/model_engine_server/db/models/exceptions.py diff --git a/server/llm_engine_server/db/models/llm_engine.py b/model-engine/model_engine_server/db/models/hosted_model_inference.py similarity index 93% rename from server/llm_engine_server/db/models/llm_engine.py rename to model-engine/model_engine_server/db/models/hosted_model_inference.py index e508c188..7661be46 100644 --- a/server/llm_engine_server/db/models/llm_engine.py +++ b/model-engine/model_engine_server/db/models/hosted_model_inference.py @@ -18,7 +18,7 @@ from sqlalchemy.ext.asyncio import AsyncSession from sqlalchemy.orm import relationship, selectinload from sqlalchemy.sql import func, text -from sqlalchemy.sql.expression import update +from sqlalchemy.sql.expression import delete, update from sqlalchemy.sql.schema import CheckConstraint, Index, UniqueConstraint from xid import XID @@ -105,7 +105,7 @@ class Bundle(Base): CheckConstraint( "(flavor = 'triton_enhanced_runnable_image') = (triton_enhanced_runnable_image_readiness_initial_delay_seconds IS NOT NULL)" ), - {"schema": "llm_engine"}, + {"schema": "hosted_model_inference"}, ) id = Column(Text, primary_key=True) @@ -146,6 +146,9 @@ class Bundle(Base): runnable_image_env = Column(JSON, nullable=True) runnable_image_protocol = Column(Text, nullable=True) runnable_image_readiness_initial_delay_seconds = Column(Integer, nullable=True) + runnable_image_extra_routes = Column(ARRAY(Text), nullable=True) + runnable_image_worker_command = Column(ARRAY(Text), nullable=True) + runnable_image_worker_env = Column(JSON, nullable=True) # Streaming Enhanced Runnable Image fields streaming_enhanced_runnable_image_streaming_command = Column(ARRAY(Text), nullable=True) @@ -205,6 +208,9 @@ def __init__( runnable_image_env: Optional[Dict[str, Any]] = None, runnable_image_protocol: Optional[str] = None, runnable_image_readiness_initial_delay_seconds: Optional[int] = None, + runnable_image_extra_routes: Optional[List[str]] = None, + runnable_image_worker_command: Optional[List[str]] = None, + runnable_image_worker_env: Optional[Dict[str, Any]] = None, # Streaming Enhanced Runnable Image fields streaming_enhanced_runnable_image_streaming_command: Optional[List[str]] = None, streaming_enhanced_runnable_image_streaming_predict_route: Optional[str] = None, @@ -260,6 +266,9 @@ def __init__( self.runnable_image_healthcheck_route = runnable_image_healthcheck_route self.runnable_image_env = runnable_image_env self.runnable_image_protocol = runnable_image_protocol + self.runnable_image_extra_routes = runnable_image_extra_routes + self.runnable_image_worker_command = runnable_image_worker_command + self.runnable_image_worker_env = runnable_image_worker_env self.runnable_image_readiness_initial_delay_seconds = ( runnable_image_readiness_initial_delay_seconds ) @@ -433,7 +442,7 @@ class Endpoint(Base): unique=True, postgresql_where=text("endpoint_metadata ? '_llm'"), ), - {"schema": "llm_engine"}, + {"schema": "hosted_model_inference"}, ) id = Column(Text, primary_key=True) @@ -441,7 +450,7 @@ class Endpoint(Base): created_by = Column(String(SHORT_STRING), index=True) created_at = Column(DateTime(timezone=True), server_default=func.now()) last_updated_at = Column(DateTime(timezone=True), server_default=func.now(), onupdate=time_now) - current_bundle_id = Column(Text, ForeignKey("llm_engine.bundles.id")) + current_bundle_id = Column(Text, ForeignKey("hosted_model_inference.bundles.id")) endpoint_metadata = Column(JSONB, default={}) creation_task_id = Column(Text) endpoint_type = Column(Text, default="async") @@ -623,7 +632,7 @@ async def delete(cls, session: AsyncSession, endpoint: "Endpoint") -> None: class BatchJob(Base): __tablename__ = "batch_jobs" - __table_args__ = ({"schema": "llm_engine"},) + __table_args__ = ({"schema": "hosted_model_inference"},) id = Column(Text, primary_key=True) created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False) @@ -632,9 +641,13 @@ class BatchJob(Base): created_by = Column(String(SHORT_STRING), index=True, nullable=False) owner = Column(String(SHORT_STRING), index=True, nullable=False) model_bundle_id = Column( - Text, ForeignKey("llm_engine.bundles.id", ondelete="SET NULL"), nullable=False + Text, + ForeignKey("hosted_model_inference.bundles.id", ondelete="SET NULL"), + nullable=False, + ) + model_endpoint_id = Column( + Text, ForeignKey("hosted_model_inference.endpoints.id"), nullable=True ) - model_endpoint_id = Column(Text, ForeignKey("llm_engine.endpoints.id"), nullable=True) task_ids_location = Column(Text, nullable=True) result_location = Column(Text, nullable=True) @@ -704,7 +717,7 @@ async def update_by_id( class DockerImageBatchJobBundle(Base): __tablename__ = "docker_image_batch_job_bundles" - __table_args__ = ({"schema": "llm_engine"},) + __table_args__ = ({"schema": "hosted_model_inference"},) id = Column("id", Text, primary_key=True) name = Column("name", Text, nullable=False) @@ -806,7 +819,7 @@ class Trigger(Base): __tablename__ = "triggers" __table_args__ = ( UniqueConstraint("name", "owner", name="uq_triggers_name_owner"), - {"schema": "llm_engine"}, + {"schema": "hosted_model_inference"}, ) id = Column("id", String, nullable=False, primary_key=True) @@ -820,7 +833,7 @@ class Trigger(Base): docker_image_batch_job_bundle_id = Column( "docker_image_batch_job_bundle_id", String, - ForeignKey("llm_engine.docker_image_batch_job_bundles.id"), + ForeignKey("hosted_model_inference.docker_image_batch_job_bundles.id"), nullable=False, ) default_job_config = Column("default_job_config", JSONB, nullable=True) @@ -845,3 +858,33 @@ def __init__( self.docker_image_batch_job_bundle_id = docker_image_batch_job_bundle_id self.default_job_config = default_job_config self.default_job_metadata = default_job_metadata + + @classmethod + async def create(cls, session: AsyncSession, trigger: "Trigger") -> None: + session.add(trigger) + await session.commit() + + @classmethod + async def select_all_by_owner(cls, session: AsyncSession, owner: str) -> List["Trigger"]: + triggers = await session.execute(select(Trigger).filter_by(owner=owner)) + return triggers.scalars().all() + + @classmethod + async def select_by_id(cls, session: AsyncSession, trigger_id: str) -> Optional["Trigger"]: + trigger = await session.execute(select(Trigger).filter_by(id=trigger_id)) + return trigger.scalar_one_or_none() + + @classmethod + async def update_by_id( + cls, session: AsyncSession, trigger_id: str, kwargs: Dict[str, Any] + ) -> None: + update_kwargs = kwargs.copy() + stmt = update(Trigger).where(Trigger.id == trigger_id).values(**update_kwargs) + await session.execute(stmt) + await session.commit() + + @classmethod + async def delete_by_id(cls, session: AsyncSession, trigger_id: str) -> None: + stmt = delete(Trigger).where(Trigger.id == trigger_id) + await session.execute(stmt) + await session.commit() diff --git a/server/llm_engine_server/db/models/model.py b/model-engine/model_engine_server/db/models/model.py similarity index 93% rename from server/llm_engine_server/db/models/model.py rename to model-engine/model_engine_server/db/models/model.py index 043e170b..d5c6fef9 100644 --- a/server/llm_engine_server/db/models/model.py +++ b/model-engine/model_engine_server/db/models/model.py @@ -106,9 +106,9 @@ class ModelVersion(Base): Column("model_id", Text, ForeignKey("model.models.id"), index=True, nullable=False), Column("version_number", Integer, index=True, nullable=False), Column( - "llm_engine_model_bundle_id", + "launch_model_bundle_id", Text, - # ForeignKey("llm_engine.bundles.id"), # This is currently breaking tests. + # ForeignKey("hosted_model_inference.bundles.id"), # This is currently breaking tests. index=True, nullable=True, ), @@ -116,14 +116,9 @@ class ModelVersion(Base): Column("tags", ARRAY(Text), index=True, nullable=False), Column("metadata", JSON, index=False, server_default="{}"), Column("created_by", String(SHORT_STRING), index=True, nullable=False), - Column( - "created_at", - DateTime(timezone=True), - server_default=func.now(), - nullable=False, - ), + Column("created_at", DateTime(timezone=True), server_default=func.now(), nullable=False), UniqueConstraint("model_id", "version_number", name="model_id_version_number_uc"), - UniqueConstraint("llm_engine_model_bundle_id", name="llm_engine_model_bundle_id_uc"), + UniqueConstraint("launch_model_bundle_id", name="launch_model_bundle_id_uc"), UniqueConstraint("nucleus_model_id", name="nucleus_model_id_uc"), schema="model", ) @@ -132,7 +127,7 @@ def __init__( self, model_id: Optional[str] = None, version_number: Optional[int] = None, - llm_engine_model_bundle_id: Optional[str] = None, + launch_model_bundle_id: Optional[str] = None, nucleus_model_id: Optional[str] = None, tags: Optional[List[str]] = None, metadata: Optional[Any] = None, @@ -142,7 +137,7 @@ def __init__( self.id = f"mov_{get_xid()}" self.model_id = model_id self.version_number = version_number - self.llm_engine_model_bundle_id = llm_engine_model_bundle_id + self.launch_model_bundle_id = launch_model_bundle_id self.nucleus_model_id = nucleus_model_id self.tags = tags or [] self.metadata = metadata @@ -175,11 +170,11 @@ def select( return models @staticmethod - def select_by_llm_engine_model_bundle_id( - session: Session, llm_engine_model_bundle_id: str + def select_by_launch_model_bundle_id( + session: Session, launch_model_bundle_id: str ) -> Optional["ModelVersion"]: model_version = session.execute( - select(ModelVersion).filter_by(llm_engine_model_bundle_id=llm_engine_model_bundle_id) + select(ModelVersion).filter_by(launch_model_bundle_id=launch_model_bundle_id) ).scalar_one_or_none() return model_version diff --git a/server/llm_engine_server/domain/use_cases/__init__.py b/model-engine/model_engine_server/db/models/utils/__init__.py similarity index 100% rename from server/llm_engine_server/domain/use_cases/__init__.py rename to model-engine/model_engine_server/db/models/utils/__init__.py diff --git a/server/llm_engine_server/db/models/utils/misc.py b/model-engine/model_engine_server/db/models/utils/misc.py similarity index 100% rename from server/llm_engine_server/db/models/utils/misc.py rename to model-engine/model_engine_server/db/models/utils/misc.py diff --git a/server/llm_engine_server/entrypoints/__init__.py b/model-engine/model_engine_server/domain/__init__.py similarity index 100% rename from server/llm_engine_server/entrypoints/__init__.py rename to model-engine/model_engine_server/domain/__init__.py diff --git a/server/llm_engine_server/inference/__init__.py b/model-engine/model_engine_server/domain/authorization/__init__.py similarity index 100% rename from server/llm_engine_server/inference/__init__.py rename to model-engine/model_engine_server/domain/authorization/__init__.py diff --git a/server/llm_engine_server/domain/authorization/scale_authorization_module.py b/model-engine/model_engine_server/domain/authorization/live_authorization_module.py similarity index 67% rename from server/llm_engine_server/domain/authorization/scale_authorization_module.py rename to model-engine/model_engine_server/domain/authorization/live_authorization_module.py index 19eaab5e..f895cefe 100644 --- a/server/llm_engine_server/domain/authorization/scale_authorization_module.py +++ b/model-engine/model_engine_server/domain/authorization/live_authorization_module.py @@ -1,17 +1,19 @@ -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CreateModelBundleV1Request, CreateModelBundleV2Request, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.entities import CustomFramework, ModelBundleFrameworkType, OwnedEntity -from llm_engine_server.domain.entities.model_bundle_entity import RunnableImageLike -from llm_engine_server.domain.entities.model_endpoint_entity import ModelEndpointRecord - -LLM_ENGINE_INTEGRATION_TEST_USER: str = "62bc820451dbea002b1c5421" +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities import ( + CustomFramework, + ModelBundleFrameworkType, + ModelEndpointRecord, + OwnedEntity, + RunnableImageLike, +) -class ScaleAuthorizationModule: +class LiveAuthorizationModule: """ This class contains authorization utilities. All methods expect User objects given from authn. """ @@ -29,13 +31,9 @@ def check_access_create_bundle_v1(user: User, request: CreateModelBundleV1Reques def check_access_create_bundle_v2(user: User, request: CreateModelBundleV2Request) -> bool: """Checks whether the provided user is authorized to create the requested model bundle.""" # External customers cannot use custom images. - return ( - user.is_privileged_user - or user.user_id == LLM_ENGINE_INTEGRATION_TEST_USER - or ( - not isinstance(request.flavor, RunnableImageLike) - and not isinstance(request.flavor.framework, CustomFramework) - ) + return user.is_privileged_user or ( + not isinstance(request.flavor, RunnableImageLike) + and not isinstance(request.flavor.framework, CustomFramework) ) @staticmethod @@ -52,12 +50,12 @@ def check_access_write_owned_entity(user: User, owned_entity: OwnedEntity) -> bo @staticmethod def get_aws_role_for_user(user: User) -> str: """Returns the AWS role that should be assumed with the user's resources.""" - return ml_infra_config().profile_ml_inference_worker + return infra_config().profile_ml_inference_worker @staticmethod def get_s3_bucket_for_user(user: User) -> str: """Returns the AWS role that should be assumed with the user's resources.""" - return ml_infra_config().s3_bucket + return infra_config().s3_bucket @staticmethod def check_endpoint_public_inference_for_user( diff --git a/server/llm_engine_server/domain/entities/__init__.py b/model-engine/model_engine_server/domain/entities/__init__.py similarity index 85% rename from server/llm_engine_server/domain/entities/__init__.py rename to model-engine/model_engine_server/domain/entities/__init__.py index a906eb83..a3ed7393 100644 --- a/server/llm_engine_server/domain/entities/__init__.py +++ b/model-engine/model_engine_server/domain/entities/__init__.py @@ -6,10 +6,13 @@ BatchJobRecord, BatchJobSerializationFormat, BatchJobStatus, + DockerImageBatchJob, ) -from .common_types import CpuSpecificationType, StorageSpecificationType +from .common_types import CpuSpecificationType, FineTuneHparamValueType, StorageSpecificationType +from .file_entity import FileMetadata from .gpu_type import GpuType from .llm_entity import LLMInferenceFramework, LLMMetadata, LLMSource, Quantization +from .llm_fine_tune_entity import LLMFineTuneEvent from .model_bundle_entity import ( ArtifactLike, CloudpickleArtifactFlavor, @@ -44,8 +47,9 @@ ModelEndpointUserConfigState, ) from .owned_entity import OwnedEntity +from .trigger_entity import Trigger -__all__: Sequence[str] = ( +__all__: Sequence[str] = [ "ArtifactLike", "BatchJob", "BatchJobProgress", @@ -58,7 +62,11 @@ "CloudpickleArtifactFlavor", "CpuSpecificationType", "CustomFramework", + "DockerImageBatchJob", + "FileMetadata", "GpuType", + "FineTuneHparamValueType", + "LLMFineTuneEvent", "LLMInferenceFramework", "LLMMetadata", "LLMSource", @@ -86,6 +94,7 @@ "StorageSpecificationType", "StreamingEnhancedRunnableImageFlavor", "TensorflowFramework", + "Trigger", "TritonEnhancedRunnableImageFlavor", "ZipArtifactFlavor", -) +] diff --git a/server/llm_engine_server/domain/entities/batch_job_entity.py b/model-engine/model_engine_server/domain/entities/batch_job_entity.py similarity index 54% rename from server/llm_engine_server/domain/entities/batch_job_entity.py rename to model-engine/model_engine_server/domain/entities/batch_job_entity.py index fe16398e..a1b2ea1b 100644 --- a/server/llm_engine_server/domain/entities/batch_job_entity.py +++ b/model-engine/model_engine_server/domain/entities/batch_job_entity.py @@ -1,11 +1,11 @@ from datetime import datetime from enum import Enum -from typing import Optional +from typing import Dict, Optional -from llm_engine_server.domain.entities.model_bundle_entity import ModelBundle -from llm_engine_server.domain.entities.model_endpoint_entity import ModelEndpoint -from llm_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import BaseModel +from model_engine_server.common.pydantic_types import BaseModel +from model_engine_server.domain.entities.model_bundle_entity import ModelBundle +from model_engine_server.domain.entities.model_endpoint_entity import ModelEndpoint +from model_engine_server.domain.entities.owned_entity import OwnedEntity class BatchJobStatus(str, Enum): @@ -26,24 +26,24 @@ class BatchJobSerializationFormat(str, Enum): class BatchJobRecord(OwnedEntity): id: str created_at: datetime - completed_at: Optional[datetime] + completed_at: Optional[datetime] = None status: BatchJobStatus created_by: str owner: str model_bundle: ModelBundle - model_endpoint_id: Optional[str] - task_ids_location: Optional[str] - result_location: Optional[str] + model_endpoint_id: Optional[str] = None + task_ids_location: Optional[str] = None + result_location: Optional[str] = None class BatchJobProgress(BaseModel): - num_tasks_pending: Optional[int] - num_tasks_completed: Optional[int] + num_tasks_pending: Optional[int] = None + num_tasks_completed: Optional[int] = None class BatchJob(BaseModel): record: BatchJobRecord - model_endpoint: Optional[ModelEndpoint] + model_endpoint: Optional[ModelEndpoint] = None progress: BatchJobProgress @@ -57,5 +57,8 @@ class DockerImageBatchJob(BaseModel): created_by: str owner: str created_at: datetime - completed_at: Optional[datetime] + completed_at: Optional[datetime] = None status: BatchJobStatus # the status map relatively nicely onto BatchJobStatus + annotations: Optional[Dict[str, str]] = None + override_job_max_runtime_s: Optional[int] = None + num_workers: Optional[int] = 1 diff --git a/model-engine/model_engine_server/domain/entities/common_types.py b/model-engine/model_engine_server/domain/entities/common_types.py new file mode 100644 index 00000000..ea0c2240 --- /dev/null +++ b/model-engine/model_engine_server/domain/entities/common_types.py @@ -0,0 +1,7 @@ +from typing import Any, Dict, Union + +CpuSpecificationType = Union[str, int, float] +StorageSpecificationType = Union[str, int, float] # TODO(phil): we can make this more specific. +FineTuneHparamValueType = Union[ + str, int, float, Dict[str, Any] +] # should just make this Any if we need to add more diff --git a/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py b/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py new file mode 100644 index 00000000..a3914e3f --- /dev/null +++ b/model-engine/model_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py @@ -0,0 +1,26 @@ +import datetime +from typing import Dict, List, Optional + +from model_engine_server.common.pydantic_types import ConfigDict +from model_engine_server.domain.entities import GpuType +from model_engine_server.domain.entities.owned_entity import OwnedEntity + + +class DockerImageBatchJobBundle(OwnedEntity): + id: str + name: str + created_by: str + created_at: datetime.datetime + owner: str + image_repository: str + image_tag: str + command: List[str] + env: Dict[str, str] + mount_location: Optional[str] = None + cpus: Optional[str] = None + memory: Optional[str] = None + storage: Optional[str] = None + gpus: Optional[int] = None + gpu_type: Optional[GpuType] = None + public: Optional[bool] = None + model_config = ConfigDict(from_attributes=True) diff --git a/model-engine/model_engine_server/domain/entities/file_entity.py b/model-engine/model_engine_server/domain/entities/file_entity.py new file mode 100644 index 00000000..f4d5a1f4 --- /dev/null +++ b/model-engine/model_engine_server/domain/entities/file_entity.py @@ -0,0 +1,15 @@ +from datetime import datetime + +from model_engine_server.common.pydantic_types import BaseModel + + +class FileMetadata(BaseModel): + """ + This is the entity-layer class for a File from the Files API. + """ + + id: str + filename: str + size: int + owner: str + updated_at: datetime diff --git a/model-engine/model_engine_server/domain/entities/gpu_type.py b/model-engine/model_engine_server/domain/entities/gpu_type.py new file mode 100644 index 00000000..6c686c01 --- /dev/null +++ b/model-engine/model_engine_server/domain/entities/gpu_type.py @@ -0,0 +1,13 @@ +from enum import Enum + + +class GpuType(str, Enum): + """Lists allowed GPU types for Launch.""" + + NVIDIA_TESLA_T4 = "nvidia-tesla-t4" + NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" + NVIDIA_AMPERE_A100 = "nvidia-ampere-a100" + NVIDIA_AMPERE_A100E = "nvidia-ampere-a100e" + NVIDIA_HOPPER_H100 = "nvidia-hopper-h100" + NVIDIA_HOPPER_H100_1G_20GB = "nvidia-hopper-h100-1g20gb" + NVIDIA_HOPPER_H100_3G_40GB = "nvidia-hopper-h100-3g40gb" diff --git a/server/llm_engine_server/domain/entities/llm_entity.py b/model-engine/model_engine_server/domain/entities/llm_entity.py similarity index 75% rename from server/llm_engine_server/domain/entities/llm_entity.py rename to model-engine/model_engine_server/domain/entities/llm_entity.py index f9062709..937a739f 100644 --- a/server/llm_engine_server/domain/entities/llm_entity.py +++ b/model-engine/model_engine_server/domain/entities/llm_entity.py @@ -10,10 +10,14 @@ class LLMSource(str, Enum): class LLMInferenceFramework(str, Enum): DEEPSPEED = "deepspeed" TEXT_GENERATION_INFERENCE = "text_generation_inference" + VLLM = "vllm" + LIGHTLLM = "lightllm" + TENSORRT_LLM = "tensorrt_llm" class Quantization(str, Enum): BITSANDBYTES = "bitsandbytes" + AWQ = "awq" @dataclass @@ -24,3 +28,5 @@ class LLMMetadata: inference_framework_image_tag: str num_shards: int quantize: Optional[Quantization] = None + checkpoint_path: Optional[str] = None + chat_template_override: Optional[str] = None diff --git a/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py b/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py new file mode 100644 index 00000000..14a0c97a --- /dev/null +++ b/model-engine/model_engine_server/domain/entities/llm_fine_tune_entity.py @@ -0,0 +1,17 @@ +from typing import Any, Dict, List, Optional + +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict + + +class LLMFineTuneTemplate(BaseModel): + docker_image_batch_job_bundle_id: str + launch_endpoint_config: Dict[str, Any] + default_hparams: Dict[str, Any] + required_params: List[str] + model_config = ConfigDict(from_attributes=True) + + +class LLMFineTuneEvent(BaseModel): + timestamp: Optional[float] = None + message: str + level: str diff --git a/server/llm_engine_server/domain/entities/model_bundle_entity.py b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py similarity index 76% rename from server/llm_engine_server/domain/entities/model_bundle_entity.py rename to model-engine/model_engine_server/domain/entities/model_bundle_entity.py index 70e03494..40e26670 100644 --- a/server/llm_engine_server/domain/entities/model_bundle_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_bundle_entity.py @@ -3,9 +3,9 @@ from enum import Enum from typing import Any, Dict, List, Optional, Union -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME -from llm_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import BaseModel, Field, root_validator +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME +from model_engine_server.common.pydantic_types import BaseModel, ConfigDict, Field, model_validator +from model_engine_server.domain.entities.owned_entity import OwnedEntity from typing_extensions import Literal @@ -18,6 +18,7 @@ class ModelBundlePackagingType(str, Enum): CLOUDPICKLE = "cloudpickle" ZIP = "zip" + LIRA = "lira" class ModelBundleFrameworkType(str, Enum): @@ -37,12 +38,12 @@ class ModelBundleEnvironmentParams(BaseModel): """ framework_type: ModelBundleFrameworkType - pytorch_image_tag: Optional[str] # for pytorch - tensorflow_version: Optional[str] # for tensorflow - ecr_repo: Optional[str] # for custom base image - image_tag: Optional[str] # for custom base image + pytorch_image_tag: Optional[str] = None # for pytorch + tensorflow_version: Optional[str] = None # for tensorflow + ecr_repo: Optional[str] = None # for custom base image + image_tag: Optional[str] = None # for custom base image - @root_validator + @model_validator(mode="before") @classmethod def validate_fields_present_for_framework_type(cls, field_values): """ @@ -71,12 +72,7 @@ def validate_fields_present_for_framework_type(cls, field_values): ) return field_values - class Config: - """ - Model Bundle Environment Params Config class. - """ - - orm_mode = True + model_config = ConfigDict(from_attributes=True) class PytorchFramework(BaseModel): @@ -126,7 +122,7 @@ class ArtifactLike(BaseModel, ABC): framework: Union[PytorchFramework, TensorflowFramework, CustomFramework] = Field( ..., discriminator="framework_type" ) - app_config: Optional[Dict[str, Any]] + app_config: Optional[Dict[str, Any]] = None location: str @@ -158,9 +154,12 @@ class RunnableImageLike(BaseModel, ABC): command: List[str] predict_route: str = "/predict" healthcheck_route: str = "/readyz" - env: Optional[Dict[str, str]] + env: Optional[Dict[str, str]] = None protocol: Literal["http"] # TODO: add support for other protocols (e.g. grpc) readiness_initial_delay_seconds: int = 120 + extra_routes: List[str] = Field(default_factory=list) + worker_command: Optional[List[str]] = None + worker_env: Optional[Dict[str, str]] = None class RunnableImageFlavor(RunnableImageLike): @@ -176,11 +175,11 @@ class TritonEnhancedRunnableImageFlavor(RunnableImageLike): flavor: Literal[ModelBundleFlavorType.TRITON_ENHANCED_RUNNABLE_IMAGE] triton_model_repository: str - triton_model_replicas: Optional[Dict[str, int]] + triton_model_replicas: Optional[Dict[str, str]] = None triton_num_cpu: float triton_commit_tag: str - triton_storage: Optional[str] - triton_memory: Optional[str] + triton_storage: Optional[str] = None + triton_memory: Optional[str] = None triton_readiness_initial_delay_seconds: int = 300 # will default to 300 seconds @@ -216,27 +215,28 @@ class ModelBundle(OwnedEntity): created_at: datetime.datetime metadata: Dict[str, Any] model_artifact_ids: List[str] - schema_location: Optional[str] + schema_location: Optional[str] = None owner: str flavor: ModelBundleFlavors = Field(..., discriminator="flavor") # LEGACY FIELDS - requirements: Optional[List[str]] # FIXME: Delete - location: Optional[str] # FIXME: Delete - env_params: Optional[ModelBundleEnvironmentParams] # FIXME: Delete - packaging_type: Optional[ModelBundlePackagingType] # FIXME: Delete - app_config: Optional[Dict[str, Any]] # FIXME: Delete - - class Config: - """ - Model Bundle Config class. - """ - - orm_mode = True + requirements: Optional[List[str]] = None # FIXME: Delete + location: Optional[str] = None # FIXME: Delete + env_params: Optional[ModelBundleEnvironmentParams] = None # FIXME: Delete + packaging_type: Optional[ModelBundlePackagingType] = None # FIXME: Delete + app_config: Optional[Dict[str, Any]] = None # FIXME: Delete + model_config = ConfigDict(from_attributes=True) def is_runnable(self) -> bool: - """True iff the model bundle calls for it.""" - return isinstance(self.flavor, RunnableImageLike) + """True iff the model bundle calls for it. + + If it is set to 'true', then this function will only return true if the :param:`model_bundle`'s + packaging_type is `ModelBundlePackagingType.LIRA` or if the :param:`model_bundle`'s flavor is + an instance of `RunnableImageLike`. Otherwise, it will return false. + """ + return self.packaging_type == ModelBundlePackagingType.LIRA or isinstance( + self.flavor, RunnableImageLike + ) def celery_task_name(self): return LIRA_CELERY_TASK_NAME if self.is_runnable() else DEFAULT_CELERY_TASK_NAME diff --git a/server/llm_engine_server/domain/entities/model_endpoint_entity.py b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py similarity index 60% rename from server/llm_engine_server/domain/entities/model_endpoint_entity.py rename to model-engine/model_engine_server/domain/entities/model_endpoint_entity.py index ec16e2b9..f4e4db3c 100644 --- a/server/llm_engine_server/domain/entities/model_endpoint_entity.py +++ b/model-engine/model_engine_server/domain/entities/model_endpoint_entity.py @@ -3,16 +3,16 @@ from typing import Any, Dict, List, Optional, Union from fastapi.openapi.models import OpenAPI -from llm_engine_server.common import dict_not_none -from llm_engine_server.common.serialization_utils import b64_to_python_json, python_json_to_b64 -from llm_engine_server.domain.entities.common_types import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.pydantic_types import BaseModel, Field, RootModel +from model_engine_server.common.serialization_utils import b64_to_python_json, python_json_to_b64 +from model_engine_server.domain.entities.common_types import ( CpuSpecificationType, StorageSpecificationType, ) -from llm_engine_server.domain.entities.gpu_type import GpuType -from llm_engine_server.domain.entities.model_bundle_entity import ModelBundle -from llm_engine_server.domain.entities.owned_entity import OwnedEntity -from pydantic import BaseModel, Field +from model_engine_server.domain.entities.gpu_type import GpuType +from model_engine_server.domain.entities.model_bundle_entity import ModelBundle +from model_engine_server.domain.entities.owned_entity import OwnedEntity from typing_extensions import Literal ModelEndpointsSchema = OpenAPI @@ -37,14 +37,21 @@ class ModelEndpointStatus(str, Enum): class ModelEndpointResourceState(BaseModel): """ This is the entity-layer class for the resource settings per worker of a Model Endpoint. + Note: in the multinode case, there are multiple "nodes" per "worker". + "Nodes" is analogous to a single k8s pod that may take up all the GPUs on a single machine. + "Workers" is the smallest unit that a request can be made to, and consists of one leader "node" and + multiple follower "nodes" (named "worker" in the k8s LeaderWorkerSet definition). + cpus/gpus/memory/storage are per-node, thus the total consumption by a "worker" + is cpus/gpus/etc. multiplied by nodes_per_worker. """ cpus: CpuSpecificationType # TODO(phil): try to use decimal.Decimal gpus: int = Field(..., ge=0) memory: StorageSpecificationType - gpu_type: Optional[GpuType] - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] + gpu_type: Optional[GpuType] = None + storage: Optional[StorageSpecificationType] = None + nodes_per_worker: int = Field(..., ge=1) # Multinode support. >1 = multinode. + optimize_costs: Optional[bool] = None class ModelEndpointDeploymentState(BaseModel): @@ -71,8 +78,8 @@ class CallbackmTLSAuth(BaseModel): key: str -class CallbackAuth(BaseModel): - __root__: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") +class CallbackAuth(RootModel): + root: Union[CallbackBasicAuth, CallbackmTLSAuth] = Field(..., discriminator="kind") class ModelEndpointConfig(BaseModel): @@ -82,10 +89,16 @@ class ModelEndpointConfig(BaseModel): endpoint_name: str bundle_name: str - post_inference_hooks: Optional[List[str]] + post_inference_hooks: Optional[List[str]] = None user_id: Optional[str] = None + billing_queue: Optional[str] = None + billing_tags: Optional[Dict[str, Any]] = None default_callback_url: Optional[str] = None - default_callback_auth: Optional[CallbackAuth] + default_callback_auth: Optional[CallbackAuth] = None + endpoint_id: Optional[str] = None + endpoint_type: Optional[ModelEndpointType] = None + bundle_id: Optional[str] = None + labels: Optional[Dict[str, str]] = None def serialize(self) -> str: return python_json_to_b64(dict_not_none(**self.dict())) @@ -96,8 +109,8 @@ def deserialize(serialized_config: str) -> "ModelEndpointConfig": class ModelEndpointUserConfigState(BaseModel): - app_config: Optional[Dict[str, Any]] - endpoint_config: Optional[ModelEndpointConfig] + app_config: Optional[Dict[str, Any]] = None + endpoint_config: Optional[ModelEndpointConfig] = None class ModelEndpointRecord(OwnedEntity): @@ -111,15 +124,15 @@ class ModelEndpointRecord(OwnedEntity): name: str created_by: str created_at: datetime.datetime - last_updated_at: Optional[datetime.datetime] - metadata: Optional[Dict[str, Any]] + last_updated_at: Optional[datetime.datetime] = None + metadata: Optional[Dict[str, Any]] = None creation_task_id: Optional[str] = Field(default=None) endpoint_type: ModelEndpointType destination: str status: ModelEndpointStatus current_model_bundle: ModelBundle owner: str - public_inference: Optional[bool] + public_inference: Optional[bool] = None class ModelEndpointInfraState(BaseModel): @@ -130,14 +143,14 @@ class ModelEndpointInfraState(BaseModel): deployment_name: str aws_role: str results_s3_bucket: str - child_fn_info: Optional[Dict[str, Any]] + child_fn_info: Optional[Dict[str, Any]] = None labels: Dict[str, str] deployment_state: ModelEndpointDeploymentState resource_state: ModelEndpointResourceState user_config_state: ModelEndpointUserConfigState prewarm: Optional[bool] = None - high_priority: Optional[bool] - num_queued_items: Optional[int] + high_priority: Optional[bool] = None + num_queued_items: Optional[int] = None image: str @@ -147,4 +160,4 @@ class ModelEndpoint(BaseModel): """ record: ModelEndpointRecord - infra_state: Optional[ModelEndpointInfraState] + infra_state: Optional[ModelEndpointInfraState] = None diff --git a/server/llm_engine_server/domain/entities/owned_entity.py b/model-engine/model_engine_server/domain/entities/owned_entity.py similarity index 72% rename from server/llm_engine_server/domain/entities/owned_entity.py rename to model-engine/model_engine_server/domain/entities/owned_entity.py index 7ea79a0d..6eaf0737 100644 --- a/server/llm_engine_server/domain/entities/owned_entity.py +++ b/model-engine/model_engine_server/domain/entities/owned_entity.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel +from model_engine_server.common.pydantic_types import BaseModel class OwnedEntity(BaseModel): diff --git a/model-engine/model_engine_server/domain/entities/trigger_entity.py b/model-engine/model_engine_server/domain/entities/trigger_entity.py new file mode 100644 index 00000000..989b44cf --- /dev/null +++ b/model-engine/model_engine_server/domain/entities/trigger_entity.py @@ -0,0 +1,19 @@ +import datetime +from typing import Any, Dict, Optional + +from model_engine_server.common.pydantic_types import ConfigDict +from model_engine_server.domain.entities.owned_entity import OwnedEntity + + +class Trigger(OwnedEntity): + id: str + name: str + owner: str + created_by: str + created_at: datetime.datetime + + cron_schedule: str + docker_image_batch_job_bundle_id: str + default_job_config: Optional[Dict[str, Any]] = None + default_job_metadata: Optional[Dict[str, str]] = None + model_config = ConfigDict(from_attributes=True) diff --git a/model-engine/model_engine_server/domain/exceptions.py b/model-engine/model_engine_server/domain/exceptions.py new file mode 100644 index 00000000..075b4823 --- /dev/null +++ b/model-engine/model_engine_server/domain/exceptions.py @@ -0,0 +1,205 @@ +from dataclasses import dataclass + + +class DomainException(Exception): + """ + Base class for exceptions thrown for domain (business logic) errors. + """ + + +class ObjectAlreadyExistsException(DomainException): + """ + Thrown when the user tries to create a model with a name that already exists. + """ + + +class ObjectNotFoundException(DomainException): + """ + Thrown when a required object is not found, e.g. when creating a version for a nonexistent model + """ + + +class ObjectNotAuthorizedException(DomainException): + """ + Thrown when a user tries to access an object they don't own. + """ + + +class ObjectHasInvalidValueException(DomainException, ValueError): + """ + Thrown when a user tries to create an object with an invalid value. + """ + + +@dataclass +class DockerImageNotFoundException(DomainException): + """ + Thrown when a user tries to specify a custom Docker image that cannot be found. + """ + + repository: str + tag: str + + +class DockerRepositoryNotFoundException(DomainException): + """ + Thrown when a Docker repository that is trying to be accessed doesn't exist. + """ + + +class DockerBuildFailedException(DomainException): + """ + Thrown if the server failed to build a docker image. + """ + + +class ReadOnlyDatabaseException(DomainException): + """ + Thrown if the server attempted to write to a read-only database. + """ + + +class ExistingEndpointOperationInProgressException(DomainException): + """ + Thrown when a user tries to edit an endpoint that has an edit in progress + """ + + def __init__(self, message): + self.message = message + + +class EndpointDeleteFailedException(DomainException): + """ + Thrown if the server failed to delete an endpoint for whatever reason. Indicates a bug serverside + """ + + +class EndpointUnsupportedInferenceTypeException(DomainException): + """ + Thrown if the requested inference type is unsupported by the endpoint. + """ + + +class EndpointUnsupportedRequestException(DomainException): + """ + Throw if the request is unsupported by the endpoint. + """ + + +class EndpointResourceInvalidRequestException(DomainException): + """ + Thrown if the endpoint resource requests are invalid. + """ + + +class EndpointInfraStateNotFound(DomainException): + """ + Thrown if the endpoint infra_state field is expected to be not None but found to be None. + """ + + +class EndpointResourceInfraException(DomainException): + """ + Thrown if the endpoint resource request passes validation, but failed for unhandled reasons. + This corresponds to a 503 error and requires investigation by the Launch team. + """ + + +class EndpointLabelsException(DomainException): + """ + Thrown if the endpoint required labels are missing or wrong. + """ + + +class EndpointBillingTagsMalformedException(DomainException): + """ + Thrown if endpoint billing tags are malformed (i.e. wrong type, wrong keys, etc.) + """ + + +class TooManyRequestsException(DomainException): + """ + Thrown if an endpoint returns a 429 exception for too many requests. + """ + + +class NoHealthyUpstreamException(DomainException): + """ + Thrown if an endpoint returns a 503 exception for no healthy upstream. This can happen if there are zero pods + available to serve the request. + """ + + +class CorruptRecordInfraStateException(DomainException): + """ + Thrown if the data from existing state (i.e. the db, k8s, etc.) is somehow uninterpretable + by the code. This can occur if the state isn't being written to correctly, if we've missed + a migration somewhere, etc. + """ + + +class UpstreamServiceError(DomainException): + """ + Thrown to relay an upstream HTTP service error to the user. + """ + + def __init__(self, status_code: int, content: bytes): + self.status_code = status_code + self.content = content + + +class LLMFineTuningMethodNotImplementedException(DomainException): + """ + Thrown if the requested fine-tuning model/method pair is not implemented. + """ + + +class LLMFineTuningQuotaReached(DomainException): + """ + Thrown if the user has run too many fine-tunes. + """ + + +class InvalidRequestException(DomainException): + """ + Thrown if the user request is invalid. + """ + + +class CronSyntaxException(DomainException): + """ + Thrown if the requested cron schedule has invalid syntax. + """ + + +class TriggerNameAlreadyExistsException(DomainException): + """ + Thrown if the requested name already exists in the trigger repository + """ + + +class StreamPutException(DomainException): + """ + Thrown if the streaming storage gateway fails to put a record. + """ + + +class PostInferenceHooksException(DomainException): + """ + Thrown if the post inference hooks are invalid. + """ + + +class LatestImageTagNotFoundException(DomainException): + """ + Thrown if the latest image tag cannot be found. + """ + + +@dataclass +class FailToInferHardwareException(DomainException): + """ + Thrown if failed to infer hardware. + """ + + message: str diff --git a/server/llm_engine_server/domain/gateways/__init__.py b/model-engine/model_engine_server/domain/gateways/__init__.py similarity index 71% rename from server/llm_engine_server/domain/gateways/__init__.py rename to model-engine/model_engine_server/domain/gateways/__init__.py index caa45f72..9550da56 100644 --- a/server/llm_engine_server/domain/gateways/__init__.py +++ b/model-engine/model_engine_server/domain/gateways/__init__.py @@ -1,5 +1,9 @@ from .async_model_endpoint_inference_gateway import AsyncModelEndpointInferenceGateway +from .cron_job_gateway import CronJobGateway from .docker_image_batch_job_gateway import DockerImageBatchJobGateway +from .file_storage_gateway import FileStorageGateway +from .inference_autoscaling_metrics_gateway import InferenceAutoscalingMetricsGateway +from .llm_artifact_gateway import LLMArtifactGateway from .model_endpoints_schema_gateway import ModelEndpointsSchemaGateway from .model_primitive_gateway import ModelPrimitiveGateway from .monitoring_metrics_gateway import MonitoringMetricsGateway @@ -9,7 +13,11 @@ __all__ = ( "AsyncModelEndpointInferenceGateway", + "CronJobGateway", "DockerImageBatchJobGateway", + "FileStorageGateway", + "InferenceAutoscalingMetricsGateway", + "LLMArtifactGateway", "ModelEndpointsSchemaGateway", "ModelPrimitiveGateway", "MonitoringMetricsGateway", diff --git a/server/llm_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py similarity index 88% rename from server/llm_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py index 7aebae40..bff654c2 100644 --- a/server/llm_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/async_model_endpoint_inference_gateway.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, diff --git a/model-engine/model_engine_server/domain/gateways/cron_job_gateway.py b/model-engine/model_engine_server/domain/gateways/cron_job_gateway.py new file mode 100644 index 00000000..c4bb289b --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/cron_job_gateway.py @@ -0,0 +1,98 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob + + +class CronJobGateway(ABC): + """ + Base class for K8s CronJob Gateway + """ + + @abstractmethod + async def create_cronjob( + self, + *, + request_host: str, + trigger_id: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Dict[str, str], + ) -> None: + """ + Create a cron job from a bundle and trigger. + + Args: + request_host: URL to forward the batch job creation request + trigger_id: The ID of the trigger + created_by: The user who created the trigger + owner: The user who owns the trigger + cron_schedule: Cron-formatted string representing the cron job's invocation schedule + docker_image_batch_job_bundle_id: The ID of the docker image batch job bundle + default_job_config: The user-specified input to the batch job. Exposed as a file mounted at mount_location to the batch job + job_config: K8s team/product labels + resource_requests: The resource requests for the batch job + + Returns: + None + """ + pass + + @abstractmethod + async def list_jobs( + self, + *, + owner: str, + trigger_id: Optional[str], + ) -> List[DockerImageBatchJob]: + """ + Lists all docker image batch jobs spawned by the trigger with the given ID, otherwise by owner if trigger_id is None + + Args: + trigger_id: the ID of the trigger pointing to the cron job + + Returns: + List of docker image batch jobs spawned by the trigger with the given ID, otherwise by owner if trigger_id is None + """ + pass + + @abstractmethod + async def update_cronjob( + self, + *, + trigger_id: str, + cron_schedule: Optional[str], + suspend: Optional[bool], + ) -> None: + """ + Partially updates the schedule field and/or the suspend field of the specified cron job. + + Args: + trigger_id: the ID of the trigger pointing to the cron job + cron_schedule: New cron schedule parameter representing the cron job's invocation schedule + suspend: The active status of the trigger, False means paused and True means unpaused + + Returns: + None + """ + pass + + @abstractmethod + async def delete_cronjob( + self, + *, + trigger_id: str, + ) -> None: + """ + Deletes the specified cron job. + + Args: + trigger_id: the ID of the trigger pointing to the cron job + + Returns: + None + """ + pass diff --git a/server/llm_engine_server/domain/gateways/docker_image_batch_job_gateway.py b/model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py similarity index 81% rename from server/llm_engine_server/domain/gateways/docker_image_batch_job_gateway.py rename to model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py index eafd355c..66c23368 100644 --- a/server/llm_engine_server/domain/gateways/docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/docker_image_batch_job_gateway.py @@ -1,14 +1,13 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob class DockerImageBatchJobGateway(ABC): """ Base class for docker image batch job gateway - """ @abstractmethod @@ -25,6 +24,9 @@ async def create_docker_image_batch_job( resource_requests: CreateDockerImageBatchJobResourceRequests, labels: Dict[str, str], mount_location: Optional[str], + annotations: Optional[Dict[str, str]] = None, + override_job_max_runtime_s: Optional[int] = None, + num_workers: Optional[int] = 1, ) -> str: """ Create a docker image batch job @@ -38,8 +40,11 @@ async def create_docker_image_batch_job( repo: The ECR repo where the docker image running the batch job lies tag: The tag of the docker image labels: K8s team/product labels + annotations: K8s annotations resource_requests: The resource requests for the batch job. mount_location: Location on filesystem where runtime-provided file contents get mounted + override_job_max_runtime_s: Optional override for the maximum runtime of the job + num_workers: num of pods to run in this job. Coordination needs to happen between the workers. Returns: diff --git a/model-engine/model_engine_server/domain/gateways/file_storage_gateway.py b/model-engine/model_engine_server/domain/gateways/file_storage_gateway.py new file mode 100644 index 00000000..b76fdd77 --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/file_storage_gateway.py @@ -0,0 +1,94 @@ +from abc import ABC, abstractmethod +from typing import List, Optional + +from model_engine_server.domain.entities import FileMetadata + + +class FileStorageGateway(ABC): + """ + Base class for file storage gateway + """ + + @abstractmethod + async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: + """ + Get file URL from file ID + + Args: + owner: The user who owns the file. + file_id: The ID of the file. + + Returns: + The URL of the file, or None if it does not exist. + """ + pass + + @abstractmethod + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + """ + Upload a file + + Args: + owner: The user who owns the file. + filename: The name of the file. + content: The content of the file. + + Returns: + The ID of the file. + """ + pass + + @abstractmethod + async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: + """ + Get metadata about a file. + + Args: + owner: The user who owns the file. + file_id: The ID of the file. + + Returns: + Information about the file, or None if it does not exist. + """ + pass + + @abstractmethod + async def list_files(self, owner: str) -> List[FileMetadata]: + """ + List all files for a given owner. + + Args: + owner: The owner whose files to list. + + Returns: + The list of files. + """ + pass + + @abstractmethod + async def delete_file(self, owner: str, file_id: str) -> bool: + """ + Delete a file. + + Args: + owner: The user who owns the files. + file_id: The ID of the file. + + Returns: + Whether the file was deleted successfully. + """ + pass + + @abstractmethod + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + """ + Get a file's content. + + Args: + owner: The user who owns the file. + file_id: The ID of the file. + + Returns: + The content of the file, or None if it does not exist. + """ + pass diff --git a/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py new file mode 100644 index 00000000..862b6e05 --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/inference_autoscaling_metrics_gateway.py @@ -0,0 +1,36 @@ +from abc import ABC, abstractmethod + + +class InferenceAutoscalingMetricsGateway(ABC): + """ + Abstract Base Class for a gateway that emits autoscaling metrics for inference requests. Can be used in conjunction + with various autoscaler resources, e.g. a Keda ScaledObject, to autoscale inference endpoints. + """ + + @abstractmethod + async def emit_inference_autoscaling_metric(self, endpoint_id: str): + """ + On an inference request, emit a metric + """ + pass + + @abstractmethod + async def emit_prewarm_metric(self, endpoint_id: str): + """ + If you want to prewarm an endpoint, emit a metric here + """ + pass + + @abstractmethod + async def create_or_update_resources(self, endpoint_id: str): + """ + Create necessary resources for autoscaling metrics + """ + pass + + @abstractmethod + async def delete_resources(self, endpoint_id: str): + """ + Delete necessary resources for autoscaling metrics + """ + pass diff --git a/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py new file mode 100644 index 00000000..8f8ece69 --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/llm_artifact_gateway.py @@ -0,0 +1,51 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List + + +class LLMArtifactGateway(ABC): + """ + Abstract Base Class for interacting with llm artifacts. + """ + + @abstractmethod + def list_files(self, path: str, **kwargs) -> List[str]: + """ + Gets a list of files from a given path. + + Args: + path (str): path to list files + """ + pass + + @abstractmethod + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + """ + Download files from a given path to a target path. + + Args: + path (str): path to list files + target_path (str): local path to download files + overwrite (bool): whether to overwrite existing local files + """ + pass + + @abstractmethod + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: + """ + Gets a list of URLs for all files associated with a given model. + + Args: + owner (str): owner of the model + model_name (str): name of the model + """ + pass + + @abstractmethod + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + """ + Gets the model config from the model files live at given folder. + + Args: + path (str): path to model files + """ + pass diff --git a/server/llm_engine_server/domain/gateways/model_endpoints_schema_gateway.py b/model-engine/model_engine_server/domain/gateways/model_endpoints_schema_gateway.py similarity index 86% rename from server/llm_engine_server/domain/gateways/model_endpoints_schema_gateway.py rename to model-engine/model_engine_server/domain/gateways/model_endpoints_schema_gateway.py index dd2347d1..fea71580 100644 --- a/server/llm_engine_server/domain/gateways/model_endpoints_schema_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/model_endpoints_schema_gateway.py @@ -1,7 +1,7 @@ from abc import abstractmethod from typing import Sequence -from llm_engine_server.domain.entities import ModelEndpointRecord, ModelEndpointsSchema +from model_engine_server.domain.entities import ModelEndpointRecord, ModelEndpointsSchema class ModelEndpointsSchemaGateway: diff --git a/server/llm_engine_server/domain/gateways/model_primitive_gateway.py b/model-engine/model_engine_server/domain/gateways/model_primitive_gateway.py similarity index 84% rename from server/llm_engine_server/domain/gateways/model_primitive_gateway.py rename to model-engine/model_engine_server/domain/gateways/model_primitive_gateway.py index 365aadf5..0b58e17d 100644 --- a/server/llm_engine_server/domain/gateways/model_primitive_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/model_primitive_gateway.py @@ -1,12 +1,12 @@ from abc import ABC, abstractmethod from typing import Optional -from llm_engine_server.domain.entities.model_bundle_entity import ModelBundleFrameworkType +from model_engine_server.domain.entities.model_bundle_entity import ModelBundleFrameworkType class ModelPrimitiveGateway(ABC): """ - Base class for interactions with Scale Model Primitive. + Base class for interactions with Model Primitive. """ @abstractmethod diff --git a/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py new file mode 100644 index 00000000..dcad95d5 --- /dev/null +++ b/model-engine/model_engine_server/domain/gateways/monitoring_metrics_gateway.py @@ -0,0 +1,92 @@ +""" +For emitting external monitoring metrics to some sort of store e.g. datadog +Currently distinct from something emitting to a Metrics Store + +Used to calculate proportion of successful/unsuccessful requests, differentiates between +docker build vs other errors +""" + +from abc import ABC, abstractmethod +from typing import Optional + +from model_engine_server.common.dtos.llms import TokenUsage +from model_engine_server.common.pydantic_types import BaseModel +from model_engine_server.core.auth.authentication_repository import User + + +class MetricMetadata(BaseModel): + user: User + model_name: Optional[str] = None + + +class MonitoringMetricsGateway(ABC): + @abstractmethod + def emit_attempted_build_metric(self): + """ + Service builder attempted metric + """ + + @abstractmethod + def emit_successful_build_metric(self): + """ + Service builder succeeded metric + """ + + @abstractmethod + def emit_build_time_metric(self, duration_seconds: float): + """ + Service builder build time metric + """ + + @abstractmethod + def emit_image_build_cache_hit_metric(self, image_type: str): + """ + Service builder image build cache hit metric + """ + + @abstractmethod + def emit_image_build_cache_miss_metric(self, image_type: str): + """ + Service builder image build cache miss metric + """ + + @abstractmethod + def emit_docker_failed_build_metric(self): + """ + Service builder docker build failed metric + """ + + @abstractmethod + def emit_database_cache_hit_metric(self): + """ + Successful database cache metric + """ + + @abstractmethod + def emit_database_cache_miss_metric(self): + """ + Missed database cache metric + """ + + @abstractmethod + def emit_route_call_metric(self, route: str, metadata: MetricMetadata): + """ + Route call metric + """ + pass + + @abstractmethod + def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMetadata): + """ + Token count metrics + """ + pass + + @abstractmethod + def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): + """ + Sync call timeout/error metrics, emitted when sync/streaming request + times out or otherwise errors (likely due to scale-from-zero not being + fast enough, or we're out of capacity, or the upstream svc is unhealthy) + """ + pass diff --git a/server/llm_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py similarity index 72% rename from server/llm_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py index cd4ded50..5cb99973 100644 --- a/server/llm_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/streaming_model_endpoint_inference_gateway.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod -from typing import AsyncIterable +from typing import AsyncIterable, Optional -from llm_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) @@ -17,7 +17,11 @@ class StreamingModelEndpointInferenceGateway(ABC): @abstractmethod def streaming_predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, + topic: str, + predict_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool, + endpoint_name: Optional[str] = None, ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs a prediction request and returns a streaming response. diff --git a/server/llm_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py similarity index 72% rename from server/llm_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py index 0ad6921c..2bc9631e 100644 --- a/server/llm_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/sync_model_endpoint_inference_gateway.py @@ -1,7 +1,8 @@ from abc import ABC, abstractmethod +from typing import Optional -from llm_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) @@ -16,7 +17,11 @@ class SyncModelEndpointInferenceGateway(ABC): @abstractmethod async def predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, + topic: str, + predict_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool, + endpoint_name: Optional[str] = None, ) -> SyncEndpointPredictV1Response: """ Runs a prediction request and returns a response. diff --git a/server/llm_engine_server/domain/gateways/task_queue_gateway.py b/model-engine/model_engine_server/domain/gateways/task_queue_gateway.py similarity index 91% rename from server/llm_engine_server/domain/gateways/task_queue_gateway.py rename to model-engine/model_engine_server/domain/gateways/task_queue_gateway.py index a667b813..bf41892f 100644 --- a/server/llm_engine_server/domain/gateways/task_queue_gateway.py +++ b/model-engine/model_engine_server/domain/gateways/task_queue_gateway.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.tasks import CreateAsyncTaskV1Response, GetAsyncTaskV1Response +from model_engine_server.common.dtos.tasks import CreateAsyncTaskV1Response, GetAsyncTaskV1Response class TaskQueueGateway(ABC): diff --git a/server/llm_engine_server/domain/repositories/__init__.py b/model-engine/model_engine_server/domain/repositories/__init__.py similarity index 57% rename from server/llm_engine_server/domain/repositories/__init__.py rename to model-engine/model_engine_server/domain/repositories/__init__.py index 96236895..56ec32e7 100644 --- a/server/llm_engine_server/domain/repositories/__init__.py +++ b/model-engine/model_engine_server/domain/repositories/__init__.py @@ -2,10 +2,16 @@ from .docker_image_batch_job_bundle_repository import DockerImageBatchJobBundleRepository from .docker_repository import DockerRepository +from .llm_fine_tune_events_repository import LLMFineTuneEventsRepository from .model_bundle_repository import ModelBundleRepository +from .tokenizer_repository import TokenizerRepository +from .trigger_repository import TriggerRepository __all__: Sequence[str] = [ "DockerRepository", "DockerImageBatchJobBundleRepository", + "LLMFineTuneEventsRepository", "ModelBundleRepository", + "TokenizerRepository", + "TriggerRepository", ] diff --git a/server/llm_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py b/model-engine/model_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py similarity index 93% rename from server/llm_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py rename to model-engine/model_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py index 019f4150..eb3c7318 100644 --- a/server/llm_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py +++ b/model-engine/model_engine_server/domain/repositories/docker_image_batch_job_bundle_repository.py @@ -1,9 +1,9 @@ from abc import ABC, abstractmethod from typing import Dict, List, Optional, Sequence -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.domain.entities import GpuType -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.domain.entities import GpuType +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) diff --git a/server/llm_engine_server/domain/repositories/docker_repository.py b/model-engine/model_engine_server/domain/repositories/docker_repository.py similarity index 80% rename from server/llm_engine_server/domain/repositories/docker_repository.py rename to model-engine/model_engine_server/domain/repositories/docker_repository.py index 184fd9da..f8ba774c 100644 --- a/server/llm_engine_server/domain/repositories/docker_repository.py +++ b/model-engine/model_engine_server/domain/repositories/docker_repository.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Optional -from llm_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse class DockerRepository(ABC): @@ -49,6 +49,17 @@ def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: """ pass + @abstractmethod + def get_latest_image_tag(self, repository_name: str) -> str: + """ + Returns the Docker image tag of the most recently pushed image in the given repository + + Args: + repository_name: the name of the repository containing the image. + + Returns: the tag of the latest Docker image. + """ + def is_repo_name(self, repo_name: str): # We assume repository names must start with a letter and can only contain lowercase letters, numbers, hyphens, underscores, and forward slashes. # Based-off ECR naming standards diff --git a/model-engine/model_engine_server/domain/repositories/llm_fine_tune_events_repository.py b/model-engine/model_engine_server/domain/repositories/llm_fine_tune_events_repository.py new file mode 100644 index 00000000..004739ab --- /dev/null +++ b/model-engine/model_engine_server/domain/repositories/llm_fine_tune_events_repository.py @@ -0,0 +1,16 @@ +from abc import ABC, abstractmethod +from typing import List + +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent + + +class LLMFineTuneEventsRepository(ABC): + @abstractmethod + async def get_fine_tune_events( + self, user_id: str, model_endpoint_name: str + ) -> List[LLMFineTuneEvent]: + pass + + @abstractmethod + async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None: + pass diff --git a/server/llm_engine_server/domain/repositories/model_bundle_repository.py b/model-engine/model_engine_server/domain/repositories/model_bundle_repository.py similarity index 95% rename from server/llm_engine_server/domain/repositories/model_bundle_repository.py rename to model-engine/model_engine_server/domain/repositories/model_bundle_repository.py index 2dc74c9d..067df488 100644 --- a/server/llm_engine_server/domain/repositories/model_bundle_repository.py +++ b/model-engine/model_engine_server/domain/repositories/model_bundle_repository.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional, Sequence -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.domain.entities import ( ModelBundle, ModelBundleFlavors, ModelBundlePackagingType, diff --git a/model-engine/model_engine_server/domain/repositories/tokenizer_repository.py b/model-engine/model_engine_server/domain/repositories/tokenizer_repository.py new file mode 100644 index 00000000..f8ba740a --- /dev/null +++ b/model-engine/model_engine_server/domain/repositories/tokenizer_repository.py @@ -0,0 +1,18 @@ +from abc import ABC, abstractmethod + +from transformers import AutoTokenizer + + +class TokenizerRepository(ABC): + @abstractmethod + def load_tokenizer(self, model_name: str) -> AutoTokenizer: + """ + Loads a tokenizer from a model name. + + Args: + model_name: The model name to load the tokenizer for. + + Returns: + A tokenizer. + """ + pass diff --git a/model-engine/model_engine_server/domain/repositories/trigger_repository.py b/model-engine/model_engine_server/domain/repositories/trigger_repository.py new file mode 100644 index 00000000..a8fd096f --- /dev/null +++ b/model-engine/model_engine_server/domain/repositories/trigger_repository.py @@ -0,0 +1,96 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, Optional, Sequence + +from model_engine_server.domain.entities.trigger_entity import Trigger + + +class TriggerRepository(ABC): + @abstractmethod + async def create_trigger( + self, + *, + name: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Optional[Dict[str, str]], + ) -> Trigger: + """ + Creates a trigger. + Args: + name: User-set name of trigger + created_by: User creating trigger + owner: Team owning trigger + cron_schedule: Schedule of k8s CronJob + docker_image_batch_job_bundle_id: ID of docker image batch job bundle used by trigger + default_job_config: Optional config to specify parameters injected at runtime + default_job_metadata: Optional metdata tags for k8s jobs spawned by trigger + + Returns: + A trigger entity + """ + pass + + @abstractmethod + async def list_triggers( + self, + owner: str, + ) -> Sequence[Trigger]: + """ + Lists all triggers with a given owner + Args: + owner: Owner of trigger(s) + + Returns: + Sequence of trigger entities + """ + pass + + @abstractmethod + async def get_trigger( + self, + trigger_id: str, + ) -> Optional[Trigger]: + """ + Retrieves a single trigger by ID + Args: + trigger_id: ID of trigger we want + + Returns: + Associated trigger entity or None if we couldn't find it + """ + pass + + @abstractmethod + async def update_trigger( + self, + trigger_id: str, + cron_schedule: str, + ) -> bool: + """ + Updates the specified trigger's cron schedule + Args: + trigger_id: ID of trigger we want + cron_schedule: new cron schedule to replace the original + + Returns: + True or False, whether the update of the trigger was successful or not + """ + pass + + @abstractmethod + async def delete_trigger( + self, + trigger_id: str, + ) -> bool: + """ + Deletes the specified trigger + Args: + trigger_id: ID of trigger we want to delete + + Returns: + True or False, whether the deletion of the trigger was successful or not + """ + pass diff --git a/server/llm_engine_server/domain/services/__init__.py b/model-engine/model_engine_server/domain/services/__init__.py similarity index 82% rename from server/llm_engine_server/domain/services/__init__.py rename to model-engine/model_engine_server/domain/services/__init__.py index 723f62db..508a68e1 100644 --- a/server/llm_engine_server/domain/services/__init__.py +++ b/model-engine/model_engine_server/domain/services/__init__.py @@ -2,6 +2,7 @@ from .batch_job_service import BatchJobService from .endpoint_builder_service import EndpointBuilderService +from .llm_fine_tuning_service import LLMFineTuningService from .llm_model_endpoint_service import LLMModelEndpointService from .model_endpoint_service import ModelEndpointService @@ -9,5 +10,6 @@ "BatchJobService", "EndpointBuilderService", "LLMModelEndpointService", + "LLMFineTuningService", "ModelEndpointService", ] diff --git a/server/llm_engine_server/domain/services/batch_job_service.py b/model-engine/model_engine_server/domain/services/batch_job_service.py similarity index 92% rename from server/llm_engine_server/domain/services/batch_job_service.py rename to model-engine/model_engine_server/domain/services/batch_job_service.py index 9e92843d..4bac6e63 100644 --- a/server/llm_engine_server/domain/services/batch_job_service.py +++ b/model-engine/model_engine_server/domain/services/batch_job_service.py @@ -1,8 +1,8 @@ from abc import ABC, abstractmethod from typing import Dict, Optional -from llm_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests -from llm_engine_server.domain.entities import BatchJob, BatchJobSerializationFormat +from model_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests +from model_engine_server.domain.entities import BatchJob, BatchJobSerializationFormat class BatchJobService(ABC): diff --git a/server/llm_engine_server/domain/services/endpoint_builder_service.py b/model-engine/model_engine_server/domain/services/endpoint_builder_service.py similarity index 89% rename from server/llm_engine_server/domain/services/endpoint_builder_service.py rename to model-engine/model_engine_server/domain/services/endpoint_builder_service.py index ff521817..4b9079e0 100644 --- a/server/llm_engine_server/domain/services/endpoint_builder_service.py +++ b/model-engine/model_engine_server/domain/services/endpoint_builder_service.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from llm_engine_server.common.dtos.endpoint_builder import ( +from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, BuildEndpointResponse, ) diff --git a/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py b/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py new file mode 100644 index 00000000..ffc0eed9 --- /dev/null +++ b/model-engine/model_engine_server/domain/services/llm_batch_completions_service.py @@ -0,0 +1,88 @@ +from abc import ABC, abstractmethod +from typing import Dict, Optional + +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.llms import CreateBatchCompletionsEngineRequest +from model_engine_server.common.dtos.llms.batch_completion import ( + BatchCompletionsJob, + UpdateBatchCompletionsV2Request, +) +from model_engine_server.core.auth.authentication_repository import User + + +class LLMBatchCompletionsService(ABC): + """ + Base class for LLM batch completions services. + """ + + @abstractmethod + async def create_batch_job( + self, + *, + user: User, + image_repo: str, + image_tag: str, + job_request: CreateBatchCompletionsEngineRequest, + resource_requests: CreateDockerImageBatchJobResourceRequests, + max_runtime_sec: int = 24 * 60 * 60, + labels: Dict[str, str] = {}, + num_workers: Optional[int] = 1, + ) -> BatchCompletionsJob: + """ + Create a batch completion job. + + Args: + owner: The user who requested the batch job + image_repo: The docker repo where the image is stored + image_tag: The tag of the batch completions image + job_config: The user-specified input to the batch job. Exposed as a file mounted at mount_location to the batch job + labels: Labels to apply to the batch job. + resource_requests: The resource requests for the batch job. + max_runtime_sec: The timeout of the batch job in seconds. + num_workers: The number of workers to run in the job. + + Returns: + The ID of the batch job. + """ + pass + + @abstractmethod + async def get_batch_job(self, batch_job_id: str, user: User) -> Optional[BatchCompletionsJob]: + """ + Get a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + The batch job, or None if it does not exist. + """ + pass + + @abstractmethod + async def update_batch_job( + self, batch_job_id: str, request: UpdateBatchCompletionsV2Request, user: User + ) -> Optional[BatchCompletionsJob]: + """ + Get a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + The batch job, or None if it does not exist. + """ + pass + + @abstractmethod + async def cancel_batch_job(self, batch_job_id: str, user: User) -> bool: + """ + Update a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + Whether the batch job was updated successfully. + """ + pass diff --git a/model-engine/model_engine_server/domain/services/llm_fine_tuning_service.py b/model-engine/model_engine_server/domain/services/llm_fine_tuning_service.py new file mode 100644 index 00000000..e55d0527 --- /dev/null +++ b/model-engine/model_engine_server/domain/services/llm_fine_tuning_service.py @@ -0,0 +1,40 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, List, Optional + +from model_engine_server.domain.entities import FineTuneHparamValueType +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob + + +class LLMFineTuningService(ABC): + @abstractmethod + async def create_fine_tune( + self, + created_by: str, + owner: str, + model: str, + training_file: str, + validation_file: Optional[str], + fine_tuning_method: str, + hyperparameters: Dict[str, FineTuneHparamValueType], + fine_tuned_model: str, + wandb_config: Optional[Dict[str, Any]], + ) -> str: + pass + + @abstractmethod + async def get_fine_tune(self, owner: str, fine_tune_id: str) -> Optional[DockerImageBatchJob]: + pass + + @abstractmethod + async def list_fine_tunes(self, owner: str) -> List[DockerImageBatchJob]: + pass + + @abstractmethod + async def cancel_fine_tune(self, owner: str, fine_tune_id: str) -> bool: + pass + + @abstractmethod + async def get_fine_tune_model_name_from_id( + self, owner: str, fine_tune_id: str + ) -> Optional[str]: + pass diff --git a/server/llm_engine_server/domain/services/llm_model_endpoint_service.py b/model-engine/model_engine_server/domain/services/llm_model_endpoint_service.py similarity index 89% rename from server/llm_engine_server/domain/services/llm_model_endpoint_service.py rename to model-engine/model_engine_server/domain/services/llm_model_endpoint_service.py index f06279d5..07a7f9f6 100644 --- a/server/llm_engine_server/domain/services/llm_model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/llm_model_endpoint_service.py @@ -2,8 +2,8 @@ from abc import ABC, abstractmethod from typing import List, Optional -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.domain.entities import ModelEndpoint +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.domain.entities import ModelEndpoint class LLMModelEndpointService(ABC): diff --git a/server/llm_engine_server/domain/services/model_endpoint_service.py b/model-engine/model_engine_server/domain/services/model_endpoint_service.py similarity index 89% rename from server/llm_engine_server/domain/services/model_endpoint_service.py rename to model-engine/model_engine_server/domain/services/model_endpoint_service.py index b53b5ace..83a5cc2e 100644 --- a/server/llm_engine_server/domain/services/model_endpoint_service.py +++ b/model-engine/model_engine_server/domain/services/model_endpoint_service.py @@ -2,8 +2,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -13,11 +13,14 @@ ModelEndpointType, StorageSpecificationType, ) -from llm_engine_server.domain.gateways import ( +from model_engine_server.domain.gateways import ( AsyncModelEndpointInferenceGateway, StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, ) +from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import ( + InferenceAutoscalingMetricsGateway, +) class ModelEndpointService(ABC): @@ -49,6 +52,14 @@ def get_streaming_model_endpoint_inference_gateway( Returns the sync model endpoint inference gateway. """ + @abstractmethod + def get_inference_autoscaling_metrics_gateway( + self, + ) -> InferenceAutoscalingMetricsGateway: + """ + Returns the inference autoscaling metrics gateway. + """ + @abstractmethod async def create_model_endpoint( self, @@ -64,7 +75,8 @@ async def create_model_endpoint( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, min_workers: int, max_workers: int, @@ -74,6 +86,7 @@ async def create_model_endpoint( results_s3_bucket: str, prewarm: bool, high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, owner: str, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], @@ -106,6 +119,7 @@ async def create_model_endpoint( to False high_priority: Makes all pods for this endpoint higher priority to enable faster pod spinup time. Higher priority pods will displace the lower priority dummy pods from shared pool. + billing_tags: Tags that get passed to scale's billing infra owner: The team ID of the creator of the model endpoint. default_callback_url: The default callback URL to use for the model endpoint. default_callback_auth: The default callback auth to use for the model endpoint. @@ -202,6 +216,7 @@ async def update_model_endpoint( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, public_inference: Optional[bool] = None, @@ -229,6 +244,7 @@ async def update_model_endpoint( to False high_priority: Makes all pods for this endpoint higher priority to enable faster pod spinup time. Higher priority pods will displace the lower priority dummy pods from shared pool. + billing_tags: Tags that get passed to scale's billing infra default_callback_url: The default callback URL to use for the model endpoint. default_callback_auth: The default callback auth to use for the model endpoint. public_inference: Whether to allow public inference. @@ -240,3 +256,11 @@ async def update_model_endpoint( ExistingEndpointOperationInProgressException: if the endpoint is currently being edited (corresponds to an HTTP 409) """ + + @abstractmethod + def can_scale_http_endpoint_from_zero(self) -> bool: + """ + Returns whether the service can autoscale sync/stream endpoints from zero. + For instance, if particular dependencies in the cluster are not installed, then this should + return False + """ diff --git a/server/llm_engine_server/inference/async_inference/__init__.py b/model-engine/model_engine_server/domain/use_cases/__init__.py similarity index 100% rename from server/llm_engine_server/inference/async_inference/__init__.py rename to model-engine/model_engine_server/domain/use_cases/__init__.py diff --git a/server/llm_engine_server/domain/use_cases/async_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py similarity index 84% rename from server/llm_engine_server/domain/use_cases/async_inference_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py index a0a3ec9d..647905f2 100644 --- a/server/llm_engine_server/domain/use_cases/async_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/async_inference_use_cases.py @@ -1,19 +1,19 @@ -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, +) +from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.entities import ModelEndpointType -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.services.model_endpoint_service import ModelEndpointService +from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService DEFAULT_TASK_TIMEOUT_SECONDS = 86400 @@ -25,7 +25,7 @@ class CreateAsyncInferenceTaskV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, model_endpoint_id: str, request: EndpointPredictV1Request diff --git a/server/llm_engine_server/domain/use_cases/batch_job_use_cases.py b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py similarity index 80% rename from server/llm_engine_server/domain/use_cases/batch_job_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py index e6710313..eeb7ce67 100644 --- a/server/llm_engine_server/domain/use_cases/batch_job_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/batch_job_use_cases.py @@ -1,6 +1,7 @@ from datetime import datetime +from typing import Optional -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateBatchJobV1Request, CreateBatchJobV1Response, CreateDockerImageBatchJobResourceRequests, @@ -8,38 +9,39 @@ CreateDockerImageBatchJobV1Response, GetBatchJobV1Response, GetDockerImageBatchJobV1Response, + ListDockerImageBatchJobsV1Response, UpdateBatchJobV1Request, UpdateBatchJobV1Response, UpdateDockerImageBatchJobV1Request, UpdateDockerImageBatchJobV1Response, ) -from llm_engine_server.common.resource_limits import validate_resource_requests -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.resource_limits import validate_resource_requests +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, +) +from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.entities import ModelEndpointType -from llm_engine_server.domain.gateways.docker_image_batch_job_gateway import ( - DockerImageBatchJobGateway, -) -from llm_engine_server.domain.repositories import ( +from model_engine_server.domain.gateways import CronJobGateway, DockerImageBatchJobGateway +from model_engine_server.domain.repositories import ( DockerImageBatchJobBundleRepository, DockerRepository, ModelBundleRepository, + TriggerRepository, ) -from llm_engine_server.domain.services import BatchJobService, ModelEndpointService -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.services import BatchJobService, ModelEndpointService +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( validate_deployment_resources, + validate_labels, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class CreateBatchJobV1UseCase: @@ -56,15 +58,17 @@ def __init__( self.batch_job_service = batch_job_service self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, request: CreateBatchJobV1Request ) -> CreateBatchJobV1Response: + validate_labels(request.labels) validate_deployment_resources( min_workers=0, max_workers=request.resource_requests.max_workers, endpoint_type=ModelEndpointType.ASYNC, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), ) bundle = await self.model_bundle_repository.get_model_bundle( @@ -109,7 +113,7 @@ class GetBatchJobV1UseCase: """ def __init__(self, batch_job_service: BatchJobService): - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() self.batch_job_service = batch_job_service async def execute(self, user: User, batch_job_id: str) -> GetBatchJobV1Response: @@ -143,7 +147,7 @@ class UpdateBatchJobV1UseCase: def __init__(self, batch_job_service: BatchJobService): self.batch_job_service = batch_job_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, batch_job_id: str, request: UpdateBatchJobV1Request @@ -170,12 +174,11 @@ def __init__( self.docker_image_batch_job_gateway = docker_image_batch_job_gateway self.docker_image_batch_job_bundle_repository = docker_image_batch_job_bundle_repository self.docker_repository = docker_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, request: CreateDockerImageBatchJobV1Request ) -> CreateDockerImageBatchJobV1Response: - if request.docker_image_batch_job_bundle_id is not None: batch_bundle = await self.docker_image_batch_job_bundle_repository.get_docker_image_batch_job_bundle( request.docker_image_batch_job_bundle_id @@ -239,6 +242,16 @@ async def execute( gpu_type=final_requests.gpu_type, ) + validate_labels(request.labels) + + if ( + request.override_job_max_runtime_s is not None + and request.override_job_max_runtime_s < 1 + ): + raise ObjectHasInvalidValueException( + "Please supply a positive integer value for batch job's maximum runtime (`override_job_max_runtime_s`)" + ) + job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=user.user_id, owner=user.team_id, @@ -250,6 +263,7 @@ async def execute( resource_requests=final_requests, labels=request.labels, mount_location=batch_bundle.mount_location, + override_job_max_runtime_s=request.override_job_max_runtime_s, ) return CreateDockerImageBatchJobV1Response(job_id=job_id) @@ -278,6 +292,32 @@ async def execute(self, user: User, batch_job_id: str) -> GetDockerImageBatchJob return GetDockerImageBatchJobV1Response(status=job.status) +class ListDockerImageBatchJobsV1UseCase: + def __init__( + self, + trigger_repository: TriggerRepository, + cron_job_gateway: CronJobGateway, + ): + self.trigger_repository = trigger_repository + self.cron_job_gateway = cron_job_gateway + self.authz_module = LiveAuthorizationModule() + + async def execute( + self, user: User, trigger_id: Optional[str] + ) -> ListDockerImageBatchJobsV1Response: + if trigger_id: + trigger = await self.trigger_repository.get_trigger(trigger_id=trigger_id) + if trigger is None: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity(user, trigger): + raise ObjectNotAuthorizedException( + f"User {user} is not authorized for trigger {trigger_id}" + ) + + jobs = await self.cron_job_gateway.list_jobs(owner=user.team_id, trigger_id=trigger_id) + return ListDockerImageBatchJobsV1Response(jobs=jobs) + + class UpdateDockerImageBatchJobV1UseCase: """ Use case for cancelling a batch job. diff --git a/server/llm_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py b/model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py similarity index 90% rename from server/llm_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py index 77c97329..3767ffe5 100644 --- a/server/llm_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/docker_image_batch_job_bundle_use_cases.py @@ -1,21 +1,21 @@ from typing import Optional -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateDockerImageBatchJobBundleV1Request, CreateDockerImageBatchJobBundleV1Response, DockerImageBatchJobBundleV1Response, ListDockerImageBatchJobBundleV1Response, ) -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, +) +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.repositories import DockerImageBatchJobBundleRepository +from model_engine_server.domain.repositories import DockerImageBatchJobBundleRepository class CreateDockerImageBatchJobBundleV1UseCase: @@ -88,7 +88,7 @@ async def execute( class GetDockerImageBatchJobBundleByIdV1UseCase: def __init__(self, docker_image_batch_job_bundle_repo: DockerImageBatchJobBundleRepository): self.docker_image_batch_job_bundle_repo = docker_image_batch_job_bundle_repo - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, docker_image_batch_job_bundle_id: str diff --git a/model-engine/model_engine_server/domain/use_cases/file_use_cases.py b/model-engine/model_engine_server/domain/use_cases/file_use_cases.py new file mode 100644 index 00000000..47f3162a --- /dev/null +++ b/model-engine/model_engine_server/domain/use_cases/file_use_cases.py @@ -0,0 +1,97 @@ +from model_engine_server.common.dtos.files import ( + DeleteFileResponse, + GetFileContentResponse, + GetFileResponse, + ListFilesResponse, + UploadFileResponse, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ObjectNotFoundException +from model_engine_server.domain.gateways import FileStorageGateway + +logger = make_logger(logger_name()) + + +class UploadFileUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, filename: str, content: bytes) -> UploadFileResponse: + file_id = await self.file_storage_gateway.upload_file( + owner=user.team_id, + filename=filename, + content=content, + ) + return UploadFileResponse( + id=file_id, + ) + + +class GetFileUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, file_id: str) -> GetFileResponse: + file_metadata = await self.file_storage_gateway.get_file( + owner=user.team_id, + file_id=file_id, + ) + if file_metadata is None: + raise ObjectNotFoundException + return GetFileResponse( + id=file_metadata.id, + filename=file_metadata.filename, + size=file_metadata.size, + ) + + +class ListFilesUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User) -> ListFilesResponse: + files = await self.file_storage_gateway.list_files( + owner=user.team_id, + ) + return ListFilesResponse( + files=[ + GetFileResponse( + id=file_metadata.id, + filename=file_metadata.filename, + size=file_metadata.size, + ) + for file_metadata in files + ] + ) + + +class DeleteFileUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, file_id: str) -> DeleteFileResponse: + deleted = await self.file_storage_gateway.delete_file( + owner=user.team_id, + file_id=file_id, + ) + return DeleteFileResponse( + deleted=deleted, + ) + + +class GetFileContentUseCase: + def __init__(self, file_storage_gateway: FileStorageGateway): + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, file_id: str) -> GetFileContentResponse: + content = await self.file_storage_gateway.get_file_content( + owner=user.team_id, + file_id=file_id, + ) + if content is None: + raise ObjectNotFoundException + return GetFileContentResponse( + id=file_id, + content=content, + ) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py new file mode 100644 index 00000000..02466a52 --- /dev/null +++ b/model-engine/model_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py @@ -0,0 +1,274 @@ +import csv +import datetime +import json +import re +from typing import Optional + +import smart_open +from model_engine_server.common.dtos.llms import ( + CancelFineTuneResponse, + CreateFineTuneRequest, + CreateFineTuneResponse, + GetFineTuneEventsResponse, + GetFineTuneResponse, + ListFineTunesResponse, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import BatchJobStatus +from model_engine_server.domain.exceptions import ( + InvalidRequestException, + LLMFineTuningQuotaReached, + ObjectNotFoundException, +) +from model_engine_server.domain.gateways import FileStorageGateway +from model_engine_server.domain.repositories import LLMFineTuneEventsRepository +from model_engine_server.domain.services import LLMFineTuningService, ModelEndpointService + +DEFAULT_FINE_TUNING_METHOD = "lora" +REQUIRED_COLUMNS = [["prompt", "response"], ["input", "output"]] + +MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER = 5 + +MAX_SUFFIX_LENGTH = 28 +# k8s labels need to be <= 62 characters, timestamp takes 13 characters, 2 characters for periods, +# model name is currently 17 long, but want to add a bit of buffer. + +logger = make_logger(logger_name()) + + +def is_model_name_suffix_valid(model_name: str): + pattern = "^[A-Za-z0-9-]+$" # TODO can we do spaces and underscores + return bool(re.match(pattern, model_name)) and len(model_name) <= MAX_SUFFIX_LENGTH + + +def ensure_model_name_is_valid_k8s_label(model_name: str): + """ + Ensure the model name is usable as a k8s label, + since we will end up creating a deployment with the model name as a label. + """ + return re.sub("[^-A-Za-z0-9_.]", "-", model_name).lstrip("-_.")[:62].rstrip("-_.") + + +def read_csv_headers(file_location: str): + """ + Read the headers of a csv file. + This will also parse for a JSONL file and will return the first row of the file split by comma. + """ + with smart_open.open(file_location, transport_params=dict(buffer_size=1024)) as file: + csv_reader = csv.DictReader(file) + return csv_reader.fieldnames + + +def are_dataset_headers_valid(file_location: str): + """ + Ensure the dataset headers are valid with required columns 'prompt' and 'response'. + """ + current_headers = read_csv_headers(file_location) + first_line = ",".join(current_headers) + try: + object = json.loads(first_line) # JSONL file format + current_headers = object.keys() + except json.decoder.JSONDecodeError: # CSV file format + pass + return any( + [ + all(header in current_headers for header in header_group) + for header_group in REQUIRED_COLUMNS + ] + ) + + +def check_file_is_valid(file_name: Optional[str], file_type: str): + """ + Ensure the file is valid with required columns 'prompt' and 'response', isn't malformatted, and exists. + Accepts CSV and JSONL formats. + file_type: 'training' or 'validation' + """ + try: + if file_name is not None and not are_dataset_headers_valid(file_name): + raise InvalidRequestException( + f"Required column headers (one subset of {REQUIRED_COLUMNS}) not found in {file_type} dataset" + ) + except FileNotFoundError: + raise InvalidRequestException( + f"Cannot find the {file_type} file. Verify the path and file name are correct." + ) + except csv.Error as exc: + raise InvalidRequestException( + f"Cannot parse the {file_type} dataset as CSV. Details: {exc}" + ) + + +class CreateFineTuneV1UseCase: + def __init__( + self, + llm_fine_tuning_service: LLMFineTuningService, + model_endpoint_service: ModelEndpointService, + llm_fine_tune_events_repository: LLMFineTuneEventsRepository, + file_storage_gateway: FileStorageGateway, + ): + self.llm_fine_tuning_service = llm_fine_tuning_service + self.model_endpoint_service = model_endpoint_service + self.llm_fine_tune_events_repository = llm_fine_tune_events_repository + self.file_storage_gateway = file_storage_gateway + + async def execute(self, user: User, request: CreateFineTuneRequest) -> CreateFineTuneResponse: + di_batch_jobs = await self.llm_fine_tuning_service.list_fine_tunes( + owner=user.team_id, + ) + in_progress_jobs = [ + job + for job in di_batch_jobs + if job.status in [BatchJobStatus.PENDING, BatchJobStatus.RUNNING] + ] + model_endpoints = await self.model_endpoint_service.list_model_endpoints( + owner=user.team_id, name=None, order_by=None + ) + + current_jobs_and_endpoints = len(in_progress_jobs) + len(model_endpoints) + + if ( + not user.is_privileged_user + and current_jobs_and_endpoints >= MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER + ): + raise LLMFineTuningQuotaReached( + f"Limit {MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER} fine-tunes/fine-tuned endpoints per user. " + f"Cancel/delete a total of " + f"{current_jobs_and_endpoints - MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER + 1} pending or " + f"running fine-tune(s) or fine-tuned endpoints to run another fine-tune." + ) + + if request.suffix is not None and not is_model_name_suffix_valid(request.suffix): + raise InvalidRequestException( + f"User-provided suffix is invalid, must only contain alphanumeric characters and dashes and be at most {MAX_SUFFIX_LENGTH} characters" + ) + time_now = datetime.datetime.utcnow().strftime("%y%m%d-%H%M%S") + # Colons breaks our download command. Keep delimiters as `.` + fine_tuned_model = ( + f"{request.model}.{request.suffix}.{time_now}" + if request.suffix is not None + else f"{request.model}.{time_now}" + ) + + # We need to ensure fine_tuned_model conforms to the k8s label spec + # This is unfortunately a leaky abstraction. This likely goes away if we redo how we implement fine-tuned + # models though + fine_tuned_model = ensure_model_name_is_valid_k8s_label(fine_tuned_model) + + if request.training_file.startswith("file-"): + training_file_url = await self.file_storage_gateway.get_url_from_id( + user.team_id, request.training_file + ) + if training_file_url is None: + raise ObjectNotFoundException("Training file does not exist") + else: + training_file_url = request.training_file + + if request.validation_file is not None and request.validation_file.startswith("file-"): + validation_file_url = await self.file_storage_gateway.get_url_from_id( + user.team_id, request.validation_file + ) + if validation_file_url is None: + raise ObjectNotFoundException("Validation file does not exist") + else: + validation_file_url = request.validation_file + + check_file_is_valid(training_file_url, "training") + check_file_is_valid(validation_file_url, "validation") + + await self.llm_fine_tune_events_repository.initialize_events(user.team_id, fine_tuned_model) + fine_tune_id = await self.llm_fine_tuning_service.create_fine_tune( + created_by=user.user_id, + owner=user.team_id, + model=request.model, + training_file=request.training_file, # for Files API, pass file ID rather than signed URL since the latter expires; fine-tuning script will get file content from Files API + validation_file=request.validation_file, + fine_tuning_method=DEFAULT_FINE_TUNING_METHOD, + hyperparameters=request.hyperparameters, + fine_tuned_model=fine_tuned_model, + wandb_config=request.wandb_config, + ) + return CreateFineTuneResponse( + id=fine_tune_id, + ) + + +class GetFineTuneV1UseCase: + def __init__(self, llm_fine_tuning_service: LLMFineTuningService): + self.llm_fine_tuning_service = llm_fine_tuning_service + + async def execute(self, user: User, fine_tune_id: str) -> GetFineTuneResponse: + di_batch_job = await self.llm_fine_tuning_service.get_fine_tune( + owner=user.team_id, + fine_tune_id=fine_tune_id, + ) + if di_batch_job is None: + raise ObjectNotFoundException + if di_batch_job.annotations: + fine_tuned_model = di_batch_job.annotations.get("fine_tuned_model") + else: + fine_tuned_model = None + logger.warning(f"Fine-tune {di_batch_job.id} has no annotations. This is unexpected.") + return GetFineTuneResponse( + id=di_batch_job.id, + fine_tuned_model=fine_tuned_model, + status=di_batch_job.status, + ) + + +class ListFineTunesV1UseCase: + def __init__(self, llm_fine_tuning_service: LLMFineTuningService): + self.llm_fine_tuning_service = llm_fine_tuning_service + + async def execute(self, user: User) -> ListFineTunesResponse: + di_batch_jobs = await self.llm_fine_tuning_service.list_fine_tunes( + owner=user.team_id, + ) + return ListFineTunesResponse( + jobs=[ + GetFineTuneResponse( + id=job.id, + status=job.status, + fine_tuned_model=( + job.annotations.get("fine_tuned_model") if job.annotations else None + ), + ) + for job in di_batch_jobs + ] + ) + + +class CancelFineTuneV1UseCase: + def __init__(self, llm_fine_tuning_service: LLMFineTuningService): + self.llm_fine_tuning_service = llm_fine_tuning_service + + async def execute(self, user: User, fine_tune_id: str) -> CancelFineTuneResponse: + success = await self.llm_fine_tuning_service.cancel_fine_tune( + owner=user.team_id, + fine_tune_id=fine_tune_id, + ) + return CancelFineTuneResponse( + success=success, + ) + + +class GetFineTuneEventsV1UseCase: + def __init__( + self, + llm_fine_tune_events_repository: LLMFineTuneEventsRepository, + llm_fine_tuning_service: LLMFineTuningService, + ): + self.llm_fine_tune_events_repository = llm_fine_tune_events_repository + self.llm_fine_tuning_service = llm_fine_tuning_service + + async def execute(self, user: User, fine_tune_id: str) -> GetFineTuneEventsResponse: + model_endpoint_name = await self.llm_fine_tuning_service.get_fine_tune_model_name_from_id( + user.team_id, fine_tune_id + ) + if model_endpoint_name is None: + raise ObjectNotFoundException(f"Fine-tune with id {fine_tune_id} not found") + events = await self.llm_fine_tune_events_repository.get_fine_tune_events( + user_id=user.team_id, model_endpoint_name=model_endpoint_name + ) + return GetFineTuneEventsResponse(events=events) diff --git a/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py new file mode 100644 index 00000000..e2069b39 --- /dev/null +++ b/model-engine/model_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py @@ -0,0 +1,3722 @@ +""" +TODO figure out how to do: (or if we want to do it) +List model endpoint history: GET model-endpoints//history +Read model endpoint creation logs: GET model-endpoints//creation-logs +""" + +import base64 +import datetime +import json +import math +import os +import re +from dataclasses import asdict +from functools import lru_cache +from typing import Any, AsyncGenerator, AsyncIterable, Dict, List, Optional, Union + +import yaml +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.llms import ( + ChatCompletionV2Request, + ChatCompletionV2StreamSuccessChunk, + ChatCompletionV2SyncResponse, + CompletionOutput, + CompletionStreamOutput, + CompletionStreamV1Request, + CompletionStreamV1Response, + CompletionSyncV1Request, + CompletionSyncV1Response, + CreateBatchCompletionsEngineRequest, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1Response, + CreateBatchCompletionsV2Request, + CreateBatchCompletionsV2Response, + CreateLLMModelEndpointV1Request, + CreateLLMModelEndpointV1Response, + DeleteLLMEndpointResponse, + GetLLMModelEndpointV1Response, + ListLLMModelEndpointsV1Response, + ModelDownloadRequest, + ModelDownloadResponse, + TokenOutput, + UpdateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Response, +) +from model_engine_server.common.dtos.llms.batch_completion import ( + CancelBatchCompletionsV2Response, + GetBatchCompletionV2Response, + UpdateBatchCompletionsV2Request, + UpdateBatchCompletionsV2Response, +) +from model_engine_server.common.dtos.llms.completion import ( + CompletionV2Request, + CompletionV2StreamSuccessChunk, + CompletionV2SyncResponse, +) +from model_engine_server.common.dtos.llms.vllm import VLLMEndpointAdditionalArgs, VLLMModelConfig +from model_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Request, TaskStatus +from model_engine_server.common.resource_limits import validate_resource_requests +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.configmap import read_config_map +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.domain.entities import ( + GpuType, + LLMInferenceFramework, + LLMMetadata, + LLMSource, + ModelBundle, + ModelBundleFlavorType, + ModelEndpoint, + ModelEndpointType, + Quantization, + RunnableImageFlavor, + RunnableImageLike, + StreamingEnhancedRunnableImageFlavor, +) +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( + DockerImageBatchJobBundle, +) +from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, + EndpointInfraStateNotFound, + EndpointLabelsException, + EndpointUnsupportedInferenceTypeException, + EndpointUnsupportedRequestException, + FailToInferHardwareException, + InvalidRequestException, + LatestImageTagNotFoundException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + UpstreamServiceError, +) +from model_engine_server.domain.gateways import ( + DockerImageBatchJobGateway, + StreamingModelEndpointInferenceGateway, +) +from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway +from model_engine_server.domain.repositories import ( + DockerImageBatchJobBundleRepository, + DockerRepository, + ModelBundleRepository, + TokenizerRepository, +) +from model_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService +from model_engine_server.domain.services.llm_batch_completions_service import ( + LLMBatchCompletionsService, +) +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.repositories.live_tokenizer_repository import ( + SUPPORTED_MODELS_INFO, + get_models_s3_uri, +) + +from ...common.datadog_utils import add_trace_model_name, add_trace_request_id +from ..authorization.live_authorization_module import LiveAuthorizationModule +from .model_bundle_use_cases import CreateModelBundleV2UseCase +from .model_endpoint_use_cases import ( + CONVERTED_FROM_ARTIFACT_LIKE_KEY, + _handle_post_inference_hooks, + model_endpoint_entity_to_get_model_endpoint_response, + validate_billing_tags, + validate_deployment_resources, + validate_labels, + validate_post_inference_hooks, +) + +logger = make_logger(logger_name()) + +OPENAI_CHAT_COMPLETION_PATH = "/v1/chat/completions" +CHAT_TEMPLATE_MAX_LENGTH = 10_000 +CHAT_SUPPORTED_INFERENCE_FRAMEWORKS = [LLMInferenceFramework.VLLM] + +OPENAI_COMPLETION_PATH = "/v1/completions" +OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS = [LLMInferenceFramework.VLLM] + +LLM_METADATA_KEY = "_llm" +RESERVED_METADATA_KEYS = [LLM_METADATA_KEY, CONVERTED_FROM_ARTIFACT_LIKE_KEY] +VLLM_MODEL_WEIGHTS_FOLDER = "model_files" + +INFERENCE_FRAMEWORK_REPOSITORY: Dict[LLMInferenceFramework, str] = { + LLMInferenceFramework.DEEPSPEED: "instant-llm", + LLMInferenceFramework.TEXT_GENERATION_INFERENCE: hmi_config.tgi_repository, + LLMInferenceFramework.VLLM: hmi_config.vllm_repository, + LLMInferenceFramework.LIGHTLLM: hmi_config.lightllm_repository, + LLMInferenceFramework.TENSORRT_LLM: hmi_config.tensorrt_llm_repository, +} + +_SUPPORTED_MODELS_BY_FRAMEWORK = { + LLMInferenceFramework.DEEPSPEED: set( + [ + "mpt-7b", + "mpt-7b-instruct", + "flan-t5-xxl", + "llama-7b", + "gpt-j-6b", + "gpt-j-6b-zh-en", + "gpt4all-j", + "dolly-v2-12b", + "stablelm-tuned-7b", + "vicuna-13b", + ] + ), + LLMInferenceFramework.TEXT_GENERATION_INFERENCE: set( + [ + "mpt-7b", + "mpt-7b-instruct", + "flan-t5-xxl", + "llama-7b", + "llama-2-7b", + "llama-2-7b-chat", + "llama-2-13b", + "llama-2-13b-chat", + "llama-2-70b", + "llama-2-70b-chat", + "falcon-7b", + "falcon-7b-instruct", + "falcon-40b", + "falcon-40b-instruct", + "codellama-7b", + "codellama-7b-instruct", + "codellama-13b", + "codellama-13b-instruct", + "codellama-34b", + "codellama-34b-instruct", + "llm-jp-13b-instruct-full", + "llm-jp-13b-instruct-full-dolly", + "zephyr-7b-alpha", + "zephyr-7b-beta", + ] + ), + LLMInferenceFramework.VLLM: set( + [ + "mpt-7b", + "mpt-7b-instruct", + "llama-7b", + "llama-2-7b", + "llama-2-7b-chat", + "llama-2-13b", + "llama-2-13b-chat", + "llama-2-70b", + "llama-2-70b-chat", + "llama-3-8b", + "llama-3-8b-instruct", + "llama-3-8b-instruct-262k", + "llama-3-70b", + "llama-3-70b-instruct", + "llama-3-1-8b", + "llama-3-1-8b-instruct", + "llama-3-1-70b", + "llama-3-1-70b-instruct", + "llama-3-1-405b", + "llama-3-1-405b-instruct", + "llama-3-2-1b-instruct", + "llama-3-2-3b-instruct", + "llama-3-2-11b-vision-instruct", + "llama-3-2-90b-vision-instruct", + "falcon-7b", + "falcon-7b-instruct", + "falcon-40b", + "falcon-40b-instruct", + "falcon-180b", + "falcon-180b-chat", + "codellama-7b", + "codellama-7b-instruct", + "codellama-13b", + "codellama-13b-instruct", + "codellama-34b", + "codellama-34b-instruct", + "codellama-70b", + "codellama-70b-instruct", + "mistral-7b", + "mistral-7b-instruct", + "mixtral-8x7b", + "mixtral-8x7b-instruct", + "mixtral-8x22b", + "mixtral-8x22b-instruct", + "mammoth-coder-llama-2-7b", + "mammoth-coder-llama-2-13b", + "mammoth-coder-llama-2-34b", + "zephyr-7b-alpha", + "zephyr-7b-beta", + "gemma-2b", + "gemma-2b-instruct", + "gemma-7b", + "gemma-7b-instruct", + "phi-3-mini-4k-instruct", + "phi-3-mini-128k-instruct", + "phi-3-small-8k-instruct", + "phi-3-small-128k-instruct", + "phi-3-medium-4-instruct", + "phi-3-medium-128k-instruct", + "deepseek-v2", + "deepseek-v2-chat", + "deepseek-coder-v2", + "deepseek-coder-v2-instruct", + "deepseek-coder-v2-lite", + "deepseek-coder-v2-lite-instruct", + "qwen2-72b-instruct", + ] + ), + LLMInferenceFramework.LIGHTLLM: set( + [ + "llama-7b", + "llama-2-7b", + "llama-2-7b-chat", + "llama-2-13b", + "llama-2-13b-chat", + "llama-2-70b", + "llama-2-70b-chat", + ] + ), + LLMInferenceFramework.TENSORRT_LLM: set( + ["llama-2-7b", "mixtral-8x7b", "mixtral-8x7b-instruct"] + ), +} + +_SUPPORTED_QUANTIZATIONS: Dict[LLMInferenceFramework, List[Quantization]] = { + LLMInferenceFramework.DEEPSPEED: [], + LLMInferenceFramework.TEXT_GENERATION_INFERENCE: [Quantization.BITSANDBYTES], + LLMInferenceFramework.VLLM: [Quantization.AWQ], + LLMInferenceFramework.LIGHTLLM: [], + LLMInferenceFramework.TENSORRT_LLM: [], +} + + +NUM_DOWNSTREAM_REQUEST_RETRIES = 80 # has to be high enough so that the retries take the 5 minutes +DOWNSTREAM_REQUEST_TIMEOUT_SECONDS = 5 * 60 # 5 minutes + +SERVICE_NAME = "model-engine" +SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") +LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME = f"{SERVICE_NAME}-inference-framework-latest-config" +RECOMMENDED_HARDWARE_CONFIG_MAP_NAME = f"{SERVICE_NAME}-recommended-hardware-config" +if SERVICE_IDENTIFIER: + SERVICE_NAME += f"-{SERVICE_IDENTIFIER}" + + +def count_tokens(input: str, model_name: str, tokenizer_repository: TokenizerRepository) -> int: + """ + Count the number of tokens in the input string. + """ + tokenizer = tokenizer_repository.load_tokenizer(model_name) + return len(tokenizer.encode(input)) + + +async def _get_latest_batch_v2_tag(inference_framework: LLMInferenceFramework) -> str: + config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) + batch_key = f"{inference_framework}_batch_v2" + if batch_key not in config_map: + raise LatestImageTagNotFoundException( + f"Could not find latest batch job tag for inference framework {inference_framework}. key: {batch_key}" + ) + return config_map[batch_key] + + +async def _get_latest_batch_tag(inference_framework: LLMInferenceFramework) -> str: + config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) + batch_key = f"{inference_framework}_batch" + if batch_key not in config_map: + raise LatestImageTagNotFoundException( + f"Could not find latest batch job tag for inference framework {inference_framework}. key: {batch_key}" + ) + return config_map[batch_key] + + +async def _get_latest_tag(inference_framework: LLMInferenceFramework) -> str: + config_map = await read_config_map(LATEST_INFERENCE_FRAMEWORK_CONFIG_MAP_NAME) + if inference_framework not in config_map: + raise LatestImageTagNotFoundException( + f"Could not find latest tag for inference framework {inference_framework}." + ) + return config_map[inference_framework] + + +async def _get_recommended_hardware_config_map() -> Dict[str, Any]: + try: + config_map = await read_config_map(RECOMMENDED_HARDWARE_CONFIG_MAP_NAME) + except Exception as e: + logger.error( + f"Failed to read config map {RECOMMENDED_HARDWARE_CONFIG_MAP_NAME}, can't infer hardware config." + ) + raise FailToInferHardwareException( + f"Failed to read config map {RECOMMENDED_HARDWARE_CONFIG_MAP_NAME}, can't infer hardware config." + ) from e + return config_map + + +def _model_endpoint_entity_to_get_llm_model_endpoint_response( + model_endpoint: ModelEndpoint, +) -> GetLLMModelEndpointV1Response: + if ( + model_endpoint.record.metadata is None + or LLM_METADATA_KEY not in model_endpoint.record.metadata + ): + raise ObjectHasInvalidValueException( + f"Can't translate model entity to response, endpoint {model_endpoint.record.id} does not have LLM metadata." + ) + llm_metadata = model_endpoint.record.metadata.get(LLM_METADATA_KEY, {}) + response = GetLLMModelEndpointV1Response( + id=model_endpoint.record.id, + name=model_endpoint.record.name, + model_name=llm_metadata["model_name"], + source=llm_metadata["source"], + status=model_endpoint.record.status, + inference_framework=llm_metadata["inference_framework"], + inference_framework_image_tag=llm_metadata["inference_framework_image_tag"], + num_shards=llm_metadata["num_shards"], + quantize=llm_metadata.get("quantize"), + checkpoint_path=llm_metadata.get("checkpoint_path"), + chat_template_override=llm_metadata.get("chat_template_override"), + spec=model_endpoint_entity_to_get_model_endpoint_response(model_endpoint), + ) + return response + + +def validate_model_name(model_name: str, inference_framework: LLMInferenceFramework) -> None: + # TODO: replace this logic to check if the model architecture is supported instead + if model_name not in _SUPPORTED_MODELS_BY_FRAMEWORK[inference_framework]: + logger.warning( + f"Model name {model_name} may not be supported by inference framework {inference_framework}." + ) + + +def validate_num_shards( + num_shards: int, inference_framework: LLMInferenceFramework, gpus: int +) -> None: + if inference_framework == LLMInferenceFramework.DEEPSPEED: + if num_shards <= 1: + raise ObjectHasInvalidValueException("DeepSpeed requires more than 1 GPU.") + if num_shards != gpus: + raise ObjectHasInvalidValueException( + f"Num shard {num_shards} must be the same as number of GPUs {gpus} for DeepSpeed." + ) + if num_shards != gpus: + raise ObjectHasInvalidValueException( + f"Num shard {num_shards} must be equal to the number of GPUs {gpus}." + ) + + +def validate_quantization( + quantize: Optional[Quantization], inference_framework: LLMInferenceFramework +) -> None: + if quantize is not None and quantize not in _SUPPORTED_QUANTIZATIONS[inference_framework]: + raise ObjectHasInvalidValueException( + f"Quantization {quantize} is not supported for inference framework {inference_framework}. Supported quantization types are {_SUPPORTED_QUANTIZATIONS[inference_framework]}." + ) + + +def validate_chat_template( + chat_template: Optional[str], inference_framework: LLMInferenceFramework +) -> None: + if chat_template is not None: + if len(chat_template) > CHAT_TEMPLATE_MAX_LENGTH: + raise ObjectHasInvalidValueException( + f"Chat template length must be less than {CHAT_TEMPLATE_MAX_LENGTH}." + ) + + if inference_framework != LLMInferenceFramework.VLLM: + raise ObjectHasInvalidValueException( + f"Chat template is only supported for inference framework {LLMInferenceFramework.VLLM}." + ) + + +def validate_checkpoint_path_uri(checkpoint_path: str) -> None: + if ( + not checkpoint_path.startswith("s3://") + and not checkpoint_path.startswith("azure://") + and "blob.core.windows.net" not in checkpoint_path + ): + raise ObjectHasInvalidValueException( + f"Only S3 and Azure Blob Storage paths are supported. Given checkpoint path: {checkpoint_path}." + ) + if checkpoint_path.endswith(".tar"): + raise ObjectHasInvalidValueException( + f"Tar files are not supported. Given checkpoint path: {checkpoint_path}." + ) + + +def get_checkpoint_path(model_name: str, checkpoint_path_override: Optional[str]) -> str: + checkpoint_path = None + models_info = SUPPORTED_MODELS_INFO.get(model_name, None) + if checkpoint_path_override: + checkpoint_path = checkpoint_path_override + elif models_info and models_info.s3_repo: + checkpoint_path = get_models_s3_uri(models_info.s3_repo, "") # pragma: no cover + + if not checkpoint_path: + raise InvalidRequestException(f"No checkpoint path found for model {model_name}") + + validate_checkpoint_path_uri(checkpoint_path) + return checkpoint_path + + +def validate_checkpoint_files(checkpoint_files: List[str]) -> None: + """Require safetensors in the checkpoint path.""" + model_files = [f for f in checkpoint_files if "model" in f] + num_safetensors = len([f for f in model_files if f.endswith(".safetensors")]) + if num_safetensors == 0: + raise ObjectHasInvalidValueException("No safetensors found in the checkpoint path.") + + +def encode_template(chat_template: str) -> str: + """Base64 encode the chat template to safely pass it to bash.""" + + encoded = base64.b64encode(chat_template.encode("utf-8")).decode("utf-8") + return encoded + + +class CreateLLMModelBundleV1UseCase: + def __init__( + self, + create_model_bundle_use_case: CreateModelBundleV2UseCase, + model_bundle_repository: ModelBundleRepository, + llm_artifact_gateway: LLMArtifactGateway, + docker_repository: DockerRepository, + ): + self.authz_module = LiveAuthorizationModule() + self.create_model_bundle_use_case = create_model_bundle_use_case + self.model_bundle_repository = model_bundle_repository + self.llm_artifact_gateway = llm_artifact_gateway + self.docker_repository = docker_repository + + def check_docker_image_exists_for_image_tag( + self, framework_image_tag: str, repository_name: str + ): + if not self.docker_repository.image_exists( + image_tag=framework_image_tag, + repository_name=repository_name, + ): + raise DockerImageNotFoundException( + repository=repository_name, + tag=framework_image_tag, + ) + + async def execute( + self, + user: User, + endpoint_name: str, + model_name: str, + source: LLMSource, + framework: LLMInferenceFramework, + framework_image_tag: str, + endpoint_type: ModelEndpointType, + num_shards: int, + quantize: Optional[Quantization], + checkpoint_path: Optional[str], + chat_template_override: Optional[str], + nodes_per_worker: int, + additional_args: Optional[Dict[str, Any]] = None, + ) -> ModelBundle: + multinode = nodes_per_worker > 1 + if source == LLMSource.HUGGING_FACE: + self.check_docker_image_exists_for_image_tag( + framework_image_tag, INFERENCE_FRAMEWORK_REPOSITORY[framework] + ) + if multinode and framework != LLMInferenceFramework.VLLM: + raise ObjectHasInvalidValueException( + f"Multinode is not supported for framework {framework}." + ) + + if framework == LLMInferenceFramework.DEEPSPEED: + bundle_id = await self.create_deepspeed_bundle( + user, + model_name, + framework_image_tag, + endpoint_type, + endpoint_name, + ) + elif framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + bundle_id = await self.create_text_generation_inference_bundle( + user, + model_name, + framework_image_tag, + endpoint_name, + num_shards, + quantize, + checkpoint_path, + ) + elif framework == LLMInferenceFramework.VLLM: + additional_vllm_args = ( + VLLMEndpointAdditionalArgs.model_validate(additional_args) + if additional_args + else None + ) + if multinode: + bundle_id = await self.create_vllm_multinode_bundle( + user, + model_name, + framework_image_tag, + endpoint_name, + num_shards, + nodes_per_worker, + quantize, + checkpoint_path, + chat_template_override, + additional_args=additional_vllm_args, + ) + else: + bundle_id = await self.create_vllm_bundle( + user, + model_name, + framework_image_tag, + endpoint_name, + num_shards, + quantize, + checkpoint_path, + chat_template_override, + additional_args=additional_vllm_args, + ) + elif framework == LLMInferenceFramework.LIGHTLLM: + bundle_id = await self.create_lightllm_bundle( + user, + model_name, + framework_image_tag, + endpoint_name, + num_shards, + checkpoint_path, + ) + elif framework == LLMInferenceFramework.TENSORRT_LLM: + bundle_id = await self.create_tensorrt_llm_bundle( + user, + framework_image_tag, + endpoint_name, + num_shards, + checkpoint_path, + ) + else: + raise ObjectHasInvalidValueException( + f"Framework {framework} is not supported for source {source}." + ) + else: + raise ObjectHasInvalidValueException(f"Source {source} is not supported.") + + model_bundle = await self.model_bundle_repository.get_model_bundle(bundle_id) + if model_bundle is None: + raise ObjectNotFoundException(f"Model bundle {bundle_id} was not found after creation.") + return model_bundle + + async def create_text_generation_inference_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + quantize: Optional[Quantization], + checkpoint_path: Optional[str], + ): + command = [] + + # TGI requires max_input_length < max_total_tokens + max_input_length = 1024 + max_total_tokens = 2048 + if "llama-2" in model_name: + max_input_length = 4095 + max_total_tokens = 4096 + + subcommands = [] + + checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) + final_weights_folder = "model_files" + + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + framework_image_tag, + checkpoint_path, + final_weights_folder, + ) + + subcommands.append( + f"text-generation-launcher --hostname :: --model-id {final_weights_folder} --num-shard {num_shards} --port 5005 --max-input-length {max_input_length} --max-total-tokens {max_total_tokens}" + ) + + if quantize: + subcommands[-1] = subcommands[-1] + f" --quantize {quantize}" + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] + + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.tgi_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/health", + predict_route="/generate", + streaming_predict_route="/generate_stream", + env={}, + ), + metadata={}, + ), + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + + def load_model_weights_sub_commands( + self, + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code: bool = False, + ): + if checkpoint_path.startswith("s3://"): + return self.load_model_weights_sub_commands_s3( + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code, + ) + elif checkpoint_path.startswith("azure://") or "blob.core.windows.net" in checkpoint_path: + return self.load_model_weights_sub_commands_abs( + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code, + ) + else: + raise ObjectHasInvalidValueException( + f"Only S3 and Azure Blob Storage paths are supported. Given checkpoint path: {checkpoint_path}." + ) + + def load_model_weights_sub_commands_s3( + self, + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code: bool, + ): + subcommands = [] + s5cmd = "s5cmd" + + # This is a hack for now to skip installing s5cmd for text-generation-inference:0.9.3-launch_s3, + # which has s5cmd binary already baked in. Otherwise, install s5cmd if it's not already available + if ( + framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + and framework_image_tag != "0.9.3-launch_s3" + ): + subcommands.append(f"{s5cmd} > /dev/null || conda install -c conda-forge -y {s5cmd}") + else: + s5cmd = "./s5cmd" + + checkpoint_files = self.llm_artifact_gateway.list_files(checkpoint_path) + validate_checkpoint_files(checkpoint_files) + + # filter to configs ('*.model' and '*.json') and weights ('*.safetensors') + # For models that are not supported by transformers directly, we need to include '*.py' and '*.bin' + # to load the model. Only set this flag if "trust_remote_code" is set to True + file_selection_str = '--include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*"' + if trust_remote_code: + file_selection_str += ' --include "*.py"' + subcommands.append( + f"{s5cmd} --numworkers 512 cp --concurrency 10 {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) + return subcommands + + def load_model_weights_sub_commands_abs( + self, + framework, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code: bool, + ): + subcommands = [] + + subcommands.extend( + [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + ] + ) + + base_path = checkpoint_path.split("/")[-1] + if base_path.endswith(".tar"): + # If the checkpoint file is a tar file, extract it into final_weights_folder + subcommands.extend( + [ + f"azcopy copy {checkpoint_path} .", + f"mkdir -p {final_weights_folder}", + f"tar --no-same-owner -xf {base_path} -C {final_weights_folder}", + ] + ) + else: + additional_pattern = ";*.py" if trust_remote_code else "" + file_selection_str = f'--include-pattern "*.model;*.json;*.safetensors{additional_pattern}" --exclude-pattern "optimizer*"' + subcommands.append( + f"azcopy copy --recursive {file_selection_str} {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + ) + + return subcommands + + def load_model_files_sub_commands_trt_llm( + self, + checkpoint_path, + ): + """ + This function generate subcommands to load model files for TensorRT-LLM. + Each model checkpoint is constituted of two folders: `model_weights` which stores the model engine files, + and `model_tokenizer` which stores the model tokenizer files. + See llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt + and llm-engine/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt + """ + if checkpoint_path.startswith("s3://"): + subcommands = [ + f"./s5cmd --numworkers 512 cp --concurrency 50 {os.path.join(checkpoint_path, '*')} ./" + ] + else: + subcommands.extend( + [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + f"azcopy copy --recursive {os.path.join(checkpoint_path, '*')} ./", + ] + ) + return subcommands + + async def create_deepspeed_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_type: ModelEndpointType, + endpoint_unique_name: str, + ): + if endpoint_type == ModelEndpointType.STREAMING: + command = [ + "dumb-init", + "--", + "ddtrace-run", + "run-streamer", + "--http", + "production_threads", + "--concurrency", + "1", + "--config", + "/install/spellbook/inference/service--spellbook_streaming_inference.yaml", + ] + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository="instant-llm", # TODO: let user choose repo + tag=framework_image_tag, + command=command, + streaming_command=command, + env={ + "MODEL_NAME": model_name, + }, + protocol="http", + readiness_initial_delay_seconds=60, + ), + metadata={}, + ), + do_auth_check=False, + ) + ).model_bundle_id + else: + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=RunnableImageFlavor( + flavor=ModelBundleFlavorType.RUNNABLE_IMAGE, + repository="instant-llm", + tag=framework_image_tag, + command=[ + "dumb-init", + "--", + "ddtrace-run", + "run-service", + "--http", + "production_threads", + "--concurrency", + "1", + "--config", + "/install/spellbook/inference/service--spellbook_inference.yaml", + ], + env={ + "MODEL_NAME": model_name, + }, + protocol="http", + readiness_initial_delay_seconds=1800, + ), + metadata={}, + ), + do_auth_check=False, + ) + ).model_bundle_id + + def _create_vllm_bundle_command( + self, + model_name: str, + framework_image_tag: str, + num_shards: int, + quantize: Optional[Quantization], + checkpoint_path: Optional[str], + chat_template_override: Optional[str], + multinode: bool, + is_worker: bool, + nodes_per_worker: int = 1, # only used if multinode + additional_args: Optional[VLLMEndpointAdditionalArgs] = None, + ): + """ + VLLM start command for the single worker, or the leader in a LeaderWorkerSet. + """ + subcommands = [] + + checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) + + # merge additional_args with inferred_additional_args + # We assume user provided additional args takes precedence over inferred args + vllm_args = VLLMEndpointAdditionalArgs.model_validate( + { + **( + infer_addition_engine_args_from_model_name(model_name).model_dump( + exclude_none=True + ) + ), + **(additional_args.model_dump(exclude_none=True) if additional_args else {}), + } + ) + + # added as workaround since transformers doesn't support mistral yet, vllm expects "mistral" in model weights folder + final_weights_folder = "mistral_files" if "mistral" in model_name else "model_files" + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.VLLM, + framework_image_tag, + checkpoint_path, + final_weights_folder, + trust_remote_code=vllm_args.trust_remote_code or False, + ) + + if multinode: + if not is_worker: + ray_cmd = "/workspace/init_ray.sh leader --ray_cluster_size=$RAY_CLUSTER_SIZE --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" + else: + ray_cmd = "/workspace/init_ray.sh worker --ray_address=$LWS_LEADER_ADDRESS.svc.cluster.local --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" + subcommands.append(ray_cmd) + + if not is_worker: + vllm_args.tensor_parallel_size = num_shards + + if vllm_args.gpu_memory_utilization is not None: + vllm_args.enforce_eager = True + + if multinode: + vllm_args.pipeline_parallel_size = nodes_per_worker + + if chat_template_override: + vllm_args.chat_template = chat_template_override + + if quantize: + if quantize != Quantization.AWQ: + raise InvalidRequestException( + f"Quantization {quantize} is not supported by vLLM." + ) + + vllm_args.quantization = quantize + + if hmi_config.sensitive_log_mode: + vllm_args.disable_log_requests = True + + vllm_cmd = f"python -m vllm_server --model {final_weights_folder} --served-model-name {model_name} {final_weights_folder} --port 5005" + for field in VLLMEndpointAdditionalArgs.model_fields.keys(): + config_value = getattr(vllm_args, field, None) + if config_value is not None: + # Special handling for chat_template + # Need to encode the chat template as base64 to avoid issues with special characters + if field == "chat_template": + chat_template_cmd = f'export CHAT_TEMPLATE=$(echo "{encode_template(config_value)}" | base64 --decode)' + subcommands.append(chat_template_cmd) + config_value = '"$CHAT_TEMPLATE"' + + # if type of config_value is True, then only need to add the key + if isinstance(config_value, bool): + if config_value: + vllm_cmd += f" --{field.replace('_', '-')}" + else: + vllm_cmd += f" --{field.replace('_', '-')} {config_value}" + + subcommands.append(vllm_cmd) + + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] + + return command + + async def create_vllm_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + quantize: Optional[Quantization], + checkpoint_path: Optional[str], + chat_template_override: Optional[str], + additional_args: Optional[VLLMEndpointAdditionalArgs] = None, + ): + command = self._create_vllm_bundle_command( + model_name, + framework_image_tag, + num_shards, + quantize, + checkpoint_path, + chat_template_override, + multinode=False, + is_worker=False, + nodes_per_worker=1, + additional_args=additional_args, + ) + + create_model_bundle_v2_request = CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.vllm_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/health", + predict_route="/predict", + streaming_predict_route="/stream", + extra_routes=[ + OPENAI_CHAT_COMPLETION_PATH, + OPENAI_COMPLETION_PATH, + ], + env={}, + ), + metadata={}, + ) + + return ( + await self.create_model_bundle_use_case.execute( + user, + create_model_bundle_v2_request, + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + + async def create_vllm_multinode_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + nodes_per_worker: int, + quantize: Optional[Quantization], + checkpoint_path: Optional[str], + chat_template_override: Optional[str], + additional_args: Optional[VLLMEndpointAdditionalArgs] = None, + ): + leader_command = self._create_vllm_bundle_command( + model_name, + framework_image_tag, + num_shards, + quantize, + checkpoint_path, + chat_template_override, + multinode=True, + is_worker=False, + nodes_per_worker=nodes_per_worker, + additional_args=additional_args, + ) + worker_command = self._create_vllm_bundle_command( + model_name, + framework_image_tag, + num_shards, + quantize, + checkpoint_path, + chat_template_override, + multinode=True, + is_worker=True, + nodes_per_worker=nodes_per_worker, + ) + + # These env vars e.g. K8S_OWN_POD_NAME, K8S_OWN_POD_NAME, K8S_OWN_NAMESPACE, K8S_LWS_CLUSTER_SIZE will be filled in automatically for all LWS pods through + # Launch's k8s_endpoint_resource_delegate + common_vllm_envs = { + "VLLM_HOST_IP": "$(K8S_OWN_POD_NAME).$(K8S_LWS_NAME).$(K8S_OWN_NAMESPACE).svc.cluster.local", # this needs to match what's given as --own-address in the vllm start command + "NCCL_SOCKET_IFNAME": "eth0", + "GLOO_SOCKET_IFNAME": "eth0", # maybe don't need + "NCCL_DEBUG": "INFO", # TODO remove once fully tested, will keep around for now + "VLLM_LOGGING_LEVEL": "INFO", # TODO remove once fully tested, will keep around for now + "RAY_CLUSTER_SIZE": "$(K8S_LWS_CLUSTER_SIZE)", + } + + create_model_bundle_v2_request = CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.vllm_repository, + tag=framework_image_tag, + command=leader_command, + streaming_command=leader_command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/health", + predict_route="/predict", + streaming_predict_route="/stream", + extra_routes=[OPENAI_CHAT_COMPLETION_PATH, OPENAI_COMPLETION_PATH], + env=common_vllm_envs, + worker_command=worker_command, + worker_env=common_vllm_envs, + ), + metadata={}, + ) + + return ( + await self.create_model_bundle_use_case.execute( + user, + create_model_bundle_v2_request, + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + + async def create_lightllm_bundle( + self, + user: User, + model_name: str, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + checkpoint_path: Optional[str], + ): + command = [] + + # TODO: incorporate auto calculate max_total_token_num from https://github.com/ModelTC/lightllm/pull/81 + max_total_token_num = 6000 # LightLLM default + if num_shards == 1: + max_total_token_num = 15000 # Default for Llama 2 7B on 1 x A10 + elif num_shards == 2: + max_total_token_num = 21000 # Default for Llama 2 13B on 2 x A10 + elif num_shards == 4: + max_total_token_num = 70000 # Default for Llama 2 13B on 4 x A10 + max_req_input_len = 2047 + max_req_total_len = 2048 + if "llama-2" in model_name: + max_req_input_len = 4095 + max_req_total_len = 4096 + + subcommands = [] + + checkpoint_path = get_checkpoint_path(model_name, checkpoint_path) + final_weights_folder = "model_files" + subcommands += self.load_model_weights_sub_commands( + LLMInferenceFramework.LIGHTLLM, + framework_image_tag, + checkpoint_path, + final_weights_folder, + ) + + subcommands.append( + f"python -m lightllm.server.api_server --model_dir {final_weights_folder} --port 5005 --tp {num_shards} --max_total_token_num {max_total_token_num} --max_req_input_len {max_req_input_len} --max_req_total_len {max_req_total_len} --tokenizer_mode auto" + ) + + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] + + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.lightllm_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/health", + predict_route="/generate", + streaming_predict_route="/generate_stream", + env={}, + ), + metadata={}, + ), + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + + async def create_tensorrt_llm_bundle( + self, + user: User, + framework_image_tag: str, + endpoint_unique_name: str, + num_shards: int, + checkpoint_path: Optional[str], + ): + command = [] + + subcommands = [] + + if not checkpoint_path: + raise ObjectHasInvalidValueException( + "Checkpoint must be provided for TensorRT-LLM models." + ) + + validate_checkpoint_path_uri(checkpoint_path) + + subcommands += self.load_model_files_sub_commands_trt_llm( + checkpoint_path, + ) + + subcommands.append( + f"python3 launch_triton_server.py --world_size={num_shards} --model_repo=./model_repo/" + ) + + command = [ + "/bin/bash", + "-c", + ";".join(subcommands), + ] + + return ( + await self.create_model_bundle_use_case.execute( + user, + CreateModelBundleV2Request( + name=endpoint_unique_name, + schema_location="TBA", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, + repository=hmi_config.tensorrt_llm_repository, + tag=framework_image_tag, + command=command, + streaming_command=command, + protocol="http", + readiness_initial_delay_seconds=10, + healthcheck_route="/v2/health/ready", + # See https://github.com/triton-inference-server/server/blob/main/docs/protocol/extension_generate.md + predict_route="/v2/models/ensemble/generate", + streaming_predict_route="/v2/models/ensemble/generate_stream", + env={}, + ), + metadata={}, + ), + do_auth_check=False, + # Skip auth check because llm create endpoint is called as the user itself, + # but the user isn't directly making the action. It should come from the fine tune + # job. + ) + ).model_bundle_id + + +class CreateLLMModelEndpointV1UseCase: + def __init__( + self, + create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, + model_endpoint_service: ModelEndpointService, + docker_repository: DockerRepository, + llm_artifact_gateway: LLMArtifactGateway, + ): + self.authz_module = LiveAuthorizationModule() + self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case + self.model_endpoint_service = model_endpoint_service + self.docker_repository = docker_repository + self.llm_artifact_gateway = llm_artifact_gateway + + async def execute( + self, user: User, request: CreateLLMModelEndpointV1Request + ) -> CreateLLMModelEndpointV1Response: + await _fill_hardware_info(self.llm_artifact_gateway, request) + if not ( + request.gpus + and request.gpu_type + and request.cpus + and request.memory + and request.storage + and request.nodes_per_worker + ): + raise RuntimeError("Some hardware info is missing unexpectedly.") + validate_deployment_resources( + min_workers=request.min_workers, + max_workers=request.max_workers, + endpoint_type=request.endpoint_type, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), + ) + if request.gpu_type == GpuType.NVIDIA_AMPERE_A100E: # pragma: no cover + raise ObjectHasInvalidValueException( + "We have migrated A100 usage to H100. Please request for H100 instead!" + ) + if request.labels is None: + raise EndpointLabelsException("Endpoint labels cannot be None!") + + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) + validate_post_inference_hooks(user, request.post_inference_hooks) + validate_model_name(request.model_name, request.inference_framework) + validate_num_shards(request.num_shards, request.inference_framework, request.gpus) + validate_quantization(request.quantize, request.inference_framework) + validate_chat_template(request.chat_template_override, request.inference_framework) + + if request.inference_framework in [ + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + LLMInferenceFramework.TENSORRT_LLM, + ]: + if request.endpoint_type != ModelEndpointType.STREAMING: + raise ObjectHasInvalidValueException( + f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference, vLLM, LightLLM, and TensorRT-LLM." + ) + + if request.inference_framework_image_tag == "latest": + request.inference_framework_image_tag = await _get_latest_tag( + request.inference_framework + ) + + if ( + request.nodes_per_worker > 1 + and not request.inference_framework == LLMInferenceFramework.VLLM + ): + raise ObjectHasInvalidValueException( + "Multinode endpoints are only supported for VLLM models." + ) + + bundle = await self.create_llm_model_bundle_use_case.execute( + user, + endpoint_name=request.name, + model_name=request.model_name, + source=request.source, + framework=request.inference_framework, + framework_image_tag=request.inference_framework_image_tag, + endpoint_type=request.endpoint_type, + num_shards=request.num_shards, + quantize=request.quantize, + checkpoint_path=request.checkpoint_path, + chat_template_override=request.chat_template_override, + nodes_per_worker=request.nodes_per_worker, + additional_args=request.model_dump(exclude_none=True), + ) + validate_resource_requests( + bundle=bundle, + cpus=request.cpus, + memory=request.memory, + storage=request.storage, + gpus=request.gpus, + gpu_type=request.gpu_type, + ) + + prewarm = request.prewarm + if prewarm is None: + prewarm = True + + high_priority = request.high_priority + if high_priority is None: + high_priority = False + + aws_role = self.authz_module.get_aws_role_for_user(user) + results_s3_bucket = self.authz_module.get_s3_bucket_for_user(user) + + request.metadata[LLM_METADATA_KEY] = asdict( + LLMMetadata( + model_name=request.model_name, + source=request.source, + inference_framework=request.inference_framework, + inference_framework_image_tag=request.inference_framework_image_tag, + num_shards=request.num_shards, + quantize=request.quantize, + checkpoint_path=request.checkpoint_path, + chat_template_override=request.chat_template_override, + ) + ) + + model_endpoint_record = await self.model_endpoint_service.create_model_endpoint( + name=request.name, + created_by=user.user_id, + model_bundle_id=bundle.id, + endpoint_type=request.endpoint_type, + metadata=request.metadata, + post_inference_hooks=request.post_inference_hooks, + child_fn_info=None, + cpus=request.cpus, + gpus=request.gpus, + memory=request.memory, + gpu_type=request.gpu_type, + storage=request.storage, + nodes_per_worker=request.nodes_per_worker, + optimize_costs=bool(request.optimize_costs), + min_workers=request.min_workers, + max_workers=request.max_workers, + per_worker=request.per_worker, + labels=request.labels, + aws_role=aws_role, + results_s3_bucket=results_s3_bucket, + prewarm=prewarm, + high_priority=high_priority, + owner=user.team_id, + default_callback_url=request.default_callback_url, + default_callback_auth=request.default_callback_auth, + public_inference=request.public_inference, + ) + _handle_post_inference_hooks( + created_by=user.user_id, + name=request.name, + post_inference_hooks=request.post_inference_hooks, + ) + + await self.model_endpoint_service.get_inference_autoscaling_metrics_gateway().emit_prewarm_metric( + model_endpoint_record.id + ) + + return CreateLLMModelEndpointV1Response( + endpoint_creation_task_id=model_endpoint_record.creation_task_id # type: ignore + ) + + +class ListLLMModelEndpointsV1UseCase: + """ + Use case for listing all LLM Model Endpoint of a given user and model endpoint name. + Also include public_inference LLM endpoints. + """ + + def __init__(self, llm_model_endpoint_service: LLMModelEndpointService): + self.llm_model_endpoint_service = llm_model_endpoint_service + + async def execute( + self, user: User, name: Optional[str], order_by: Optional[ModelEndpointOrderBy] + ) -> ListLLMModelEndpointsV1Response: + """ + Runs the use case to list all Model Endpoints owned by the user with the given name. + + Args: + user: The owner of the model endpoint(s). + name: The name of the Model Endpoint(s). + order_by: An optional argument to specify the output ordering of the model endpoints. + + Returns: + A response object that contains the model endpoints. + """ + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=name, order_by=order_by + ) + return ListLLMModelEndpointsV1Response( + model_endpoints=[ + _model_endpoint_entity_to_get_llm_model_endpoint_response(m) + for m in model_endpoints + ] + ) + + +class GetLLMModelEndpointByNameV1UseCase: + """ + Use case for getting an LLM Model Endpoint of a given user by name. + """ + + def __init__(self, llm_model_endpoint_service: LLMModelEndpointService): + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + + async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndpointV1Response: + """ + Runs the use case to get the LLM endpoint with the given name. + + Args: + user: The owner of the model endpoint. + model_endpoint_name: The name of the model endpoint. + + Returns: + A response object that contains the model endpoint. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + """ + model_endpoint = await self.llm_model_endpoint_service.get_llm_model_endpoint( + model_endpoint_name + ) + if not model_endpoint: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + return _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + +def merge_metadata( + request: Optional[Dict[str, Any]], record: Optional[Dict[str, Any]] +) -> Optional[Dict[str, Any]]: + if request is None: + return record + if record is None: + return request + return {**record, **request} + + +class UpdateLLMModelEndpointV1UseCase: + def __init__( + self, + create_llm_model_bundle_use_case: CreateLLMModelBundleV1UseCase, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + docker_repository: DockerRepository, + ): + self.authz_module = LiveAuthorizationModule() + self.create_llm_model_bundle_use_case = create_llm_model_bundle_use_case + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.docker_repository = docker_repository + + async def execute( + self, + user: User, + model_endpoint_name: str, + request: UpdateLLMModelEndpointV1Request, + ) -> UpdateLLMModelEndpointV1Response: + if request.labels is not None: + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) + validate_post_inference_hooks(user, request.post_inference_hooks) + + model_endpoint = await self.llm_model_endpoint_service.get_llm_model_endpoint( + model_endpoint_name + ) + if not model_endpoint: + raise ObjectNotFoundException + if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): + raise ObjectNotAuthorizedException + + endpoint_record = model_endpoint.record + model_endpoint_id = endpoint_record.id + bundle = endpoint_record.current_model_bundle + + # TODO: We may want to consider what happens if an endpoint gets stuck in UPDATE_PENDING + # on first creating it, and we need to find a way to get it unstuck. This would end up + # causing endpoint.infra_state to be None. + if model_endpoint.infra_state is None: + error_msg = f"Endpoint infra state not found for {model_endpoint_name=}" + logger.error(error_msg) + raise EndpointInfraStateNotFound(error_msg) + + infra_state = model_endpoint.infra_state + metadata: Optional[Dict[str, Any]] + + if ( + request.force_bundle_recreation + or request.model_name + or request.source + or request.inference_framework_image_tag + or request.num_shards + or request.quantize + or request.checkpoint_path + or request.chat_template_override + ): + llm_metadata = (model_endpoint.record.metadata or {}).get(LLM_METADATA_KEY, {}) + inference_framework = llm_metadata["inference_framework"] + + if request.inference_framework_image_tag == "latest": + inference_framework_image_tag = await _get_latest_tag(inference_framework) + else: + inference_framework_image_tag = ( + request.inference_framework_image_tag + or llm_metadata["inference_framework_image_tag"] + ) + + model_name = request.model_name or llm_metadata["model_name"] + source = request.source or llm_metadata["source"] + num_shards = request.num_shards or llm_metadata["num_shards"] + quantize = request.quantize or llm_metadata.get("quantize") + checkpoint_path = request.checkpoint_path or llm_metadata.get("checkpoint_path") + + validate_model_name(model_name, inference_framework) + validate_num_shards( + num_shards, + inference_framework, + request.gpus or infra_state.resource_state.gpus, + ) + validate_quantization(quantize, inference_framework) + validate_chat_template(request.chat_template_override, inference_framework) + chat_template_override = request.chat_template_override or llm_metadata.get( + "chat_template_override" + ) + + bundle = await self.create_llm_model_bundle_use_case.execute( + user, + endpoint_name=model_endpoint_name, + model_name=model_name, + source=source, + framework=inference_framework, + framework_image_tag=inference_framework_image_tag, + endpoint_type=endpoint_record.endpoint_type, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + chat_template_override=chat_template_override, + nodes_per_worker=model_endpoint.infra_state.resource_state.nodes_per_worker, + additional_args=request.model_dump(exclude_none=True), + ) + + metadata = endpoint_record.metadata or {} + metadata[LLM_METADATA_KEY] = asdict( + LLMMetadata( + model_name=model_name, + source=source, + inference_framework=inference_framework, + inference_framework_image_tag=inference_framework_image_tag, + num_shards=num_shards, + quantize=quantize, + checkpoint_path=checkpoint_path, + chat_template_override=chat_template_override, + ) + ) + endpoint_record.metadata = metadata + + # For resources that are not specified in the update endpoint request, pass in resource from + # infra_state to make sure that after the update, all resources are valid and in sync. + # E.g. If user only want to update gpus and leave gpu_type as None, we use the existing gpu_type + # from infra_state to avoid passing in None to validate_resource_requests. + validate_resource_requests( + bundle=bundle, + cpus=request.cpus or infra_state.resource_state.cpus, + memory=request.memory or infra_state.resource_state.memory, + storage=request.storage or infra_state.resource_state.storage, + gpus=request.gpus or infra_state.resource_state.gpus, + gpu_type=request.gpu_type or infra_state.resource_state.gpu_type, + ) + + validate_deployment_resources( + min_workers=request.min_workers, + max_workers=request.max_workers, + endpoint_type=endpoint_record.endpoint_type, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), + ) + + if request.metadata is not None: + # If reserved metadata key is provided, throw ObjectHasInvalidValueException + for key in RESERVED_METADATA_KEYS: + if key in request.metadata: + raise ObjectHasInvalidValueException( + f"{key} is a reserved metadata key and cannot be used by user." + ) + + metadata = merge_metadata(request.metadata, endpoint_record.metadata) + + updated_endpoint_record = await self.model_endpoint_service.update_model_endpoint( + model_endpoint_id=model_endpoint_id, + model_bundle_id=bundle.id, + metadata=metadata, + post_inference_hooks=request.post_inference_hooks, + cpus=request.cpus, + gpus=request.gpus, + memory=request.memory, + gpu_type=request.gpu_type, + storage=request.storage, + optimize_costs=request.optimize_costs, + min_workers=request.min_workers, + max_workers=request.max_workers, + per_worker=request.per_worker, + labels=request.labels, + prewarm=request.prewarm, + high_priority=request.high_priority, + default_callback_url=request.default_callback_url, + default_callback_auth=request.default_callback_auth, + public_inference=request.public_inference, + ) + _handle_post_inference_hooks( + created_by=endpoint_record.created_by, + name=updated_endpoint_record.name, + post_inference_hooks=request.post_inference_hooks, + ) + + return UpdateLLMModelEndpointV1Response( + endpoint_creation_task_id=updated_endpoint_record.creation_task_id # type: ignore + ) + + +class DeleteLLMEndpointByNameUseCase: + """ + Use case for deleting an LLM Model Endpoint of a given user by endpoint name. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + ): + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + + async def execute(self, user: User, model_endpoint_name: str) -> DeleteLLMEndpointResponse: + """ + Runs the use case to delete the LLM endpoint owned by the user with the given name. + + Args: + user: The owner of the model endpoint. + model_endpoint_name: The name of the model endpoint. + + Returns: + A response object that contains a boolean indicating if deletion was successful. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + """ + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.user_id, name=model_endpoint_name, order_by=None + ) + if len(model_endpoints) != 1: + raise ObjectNotFoundException + model_endpoint = model_endpoints[0] + if not self.authz_module.check_access_write_owned_entity(user, model_endpoint.record): + raise ObjectNotAuthorizedException + await self.model_endpoint_service.delete_model_endpoint(model_endpoint.record.id) + return DeleteLLMEndpointResponse(deleted=True) + + +def deepspeed_result_to_tokens(result: Dict[str, Any]) -> List[TokenOutput]: + tokens = [] + for i in range(len(result["token_probs"]["token_probs"])): + tokens.append( + TokenOutput( + token=result["token_probs"]["tokens"][i], + log_prob=math.log(result["token_probs"]["token_probs"][i]), + ) + ) + return tokens + + +def validate_and_update_completion_params( + inference_framework: LLMInferenceFramework, + request: Union[CompletionSyncV1Request, CompletionStreamV1Request], +) -> Union[CompletionSyncV1Request, CompletionStreamV1Request]: + # top_k, top_p + if inference_framework in [ + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: + if request.temperature == 0: + if request.top_k not in [-1, None] or request.top_p not in [1.0, None]: + raise ObjectHasInvalidValueException( + "top_k and top_p can't be enabled when temperature is 0." + ) + if request.top_k == 0: + raise ObjectHasInvalidValueException( + "top_k needs to be strictly positive, or set it to be -1 / None to disable top_k." + ) + if inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + request.top_k = None if request.top_k == -1 else request.top_k + request.top_p = None if request.top_p == 1.0 else request.top_p + if inference_framework in [ + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: + request.top_k = -1 if request.top_k is None else request.top_k + request.top_p = 1.0 if request.top_p is None else request.top_p + else: + if request.top_k or request.top_p: + raise ObjectHasInvalidValueException( + "top_k and top_p are only supported in text-generation-inference, vllm, lightllm." + ) + + # presence_penalty, frequency_penalty + if inference_framework in [ + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: + request.presence_penalty = ( + 0.0 if request.presence_penalty is None else request.presence_penalty + ) + request.frequency_penalty = ( + 0.0 if request.frequency_penalty is None else request.frequency_penalty + ) + else: + if request.presence_penalty or request.frequency_penalty: + raise ObjectHasInvalidValueException( + "presence_penalty and frequency_penalty are only supported in vllm, lightllm." + ) + + # return_token_log_probs + if inference_framework in [ + LLMInferenceFramework.DEEPSPEED, + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + LLMInferenceFramework.VLLM, + LLMInferenceFramework.LIGHTLLM, + ]: + pass + else: + if request.return_token_log_probs: + raise ObjectHasInvalidValueException( + "return_token_log_probs is only supported in deepspeed, text-generation-inference, vllm, lightllm." + ) + + # include_stop_str_in_output + if inference_framework == LLMInferenceFramework.VLLM: + pass + else: + if request.include_stop_str_in_output is not None: + raise ObjectHasInvalidValueException( + "include_stop_str_in_output is only supported in vllm." + ) + + guided_count = 0 + if request.guided_choice is not None: + guided_count += 1 + if request.guided_json is not None: + guided_count += 1 + if request.guided_regex is not None: + guided_count += 1 + if request.guided_grammar is not None: + guided_count += 1 + + if guided_count > 1: + raise ObjectHasInvalidValueException( + "Only one of guided_json, guided_choice, guided_regex, guided_grammar can be enabled." + ) + + if ( + request.guided_choice is not None + or request.guided_regex is not None + or request.guided_json is not None + or request.guided_grammar is not None + ) and not inference_framework == LLMInferenceFramework.VLLM: + raise ObjectHasInvalidValueException("Guided decoding is only supported in vllm.") + + return request + + +class CompletionSyncV1UseCase: + """ + Use case for running a prompt completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + def model_output_to_completion_output( + self, + model_output: Dict[str, Any], + model_endpoint: ModelEndpoint, + prompt: str, + with_token_probs: Optional[bool], + ) -> CompletionOutput: + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: + completion_token_count = len(model_output["token_probs"]["tokens"]) + tokens = None + if with_token_probs: + tokens = deepspeed_result_to_tokens(model_output) + return CompletionOutput( + text=model_output["text"], + num_prompt_tokens=count_tokens( + prompt, + model_content.model_name, + self.tokenizer_repository, + ), + num_completion_tokens=completion_token_count, + tokens=tokens, + ) + elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + try: + tokens = None + if with_token_probs: + tokens = [ + TokenOutput(token=t["text"], log_prob=t["logprob"]) + for t in model_output["details"]["tokens"] + ] + return CompletionOutput( + text=model_output["generated_text"], + num_prompt_tokens=len(model_output["details"]["prefill"]), + num_completion_tokens=model_output["details"]["generated_tokens"], + tokens=tokens, + ) + except Exception: + logger.exception(f"Error parsing text-generation-inference output {model_output}.") + if model_output.get("error_type") == "validation": + raise InvalidRequestException(model_output.get("error")) # trigger a 400 + else: + raise UpstreamServiceError( + status_code=500, content=bytes(model_output["error"], "utf-8") + ) + + elif model_content.inference_framework == LLMInferenceFramework.VLLM: + tokens = None + if with_token_probs: + tokens = [ + TokenOutput( + token=model_output["tokens"][index], + log_prob=list(t.values())[0], + ) + for index, t in enumerate(model_output["log_probs"]) + ] + return CompletionOutput( + text=model_output["text"], + num_prompt_tokens=model_output["count_prompt_tokens"], + num_completion_tokens=model_output["count_output_tokens"], + tokens=tokens, + ) + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + tokens = None + if with_token_probs: + tokens = [ + TokenOutput(token=t["text"], log_prob=t["logprob"]) + for t in model_output["tokens"] + ] + return CompletionOutput( + text=model_output["generated_text"][0], + num_prompt_tokens=count_tokens( + prompt, + model_content.model_name, + self.tokenizer_repository, + ), + num_completion_tokens=model_output["count_output_tokens"], + tokens=tokens, + ) + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + if not model_content.model_name: + raise InvalidRequestException( + f"Invalid endpoint {model_content.name} has no base model" + ) + if not prompt: + raise InvalidRequestException("Prompt must be provided for TensorRT-LLM models.") + num_prompt_tokens = count_tokens( + prompt, model_content.model_name, self.tokenizer_repository + ) + if "token_ids" in model_output: + # TensorRT 23.10 has this field, TensorRT 24.03 does not + # For backwards compatibility with pre-2024/05/02 + num_completion_tokens = len(model_output["token_ids"]) - num_prompt_tokens + # Output is " prompt output" + text = model_output["text_output"][(len(prompt) + 4) :] + elif "output_log_probs" in model_output: + # TensorRT 24.01 + surrounding code. + # For some reason TRT returns output_log_probs as either a list or a float + # Also the log probs don't look right, so returning log-probs is still broken + num_completion_tokens = ( + len(model_output["output_log_probs"]) + if type(model_output["output_log_probs"]) is list + else 1 + ) + # Output is just "output". See `exclude_input_in_output` inside of + # inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt + text = model_output["text_output"] + return CompletionOutput( + text=text, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) + else: + raise EndpointUnsupportedInferenceTypeException( + f"Unsupported inference framework {model_content.inference_framework}" + ) + + async def execute( + self, user: User, model_endpoint_name: str, request: CompletionSyncV1Request + ) -> CompletionSyncV1Response: + """ + Runs the use case to create a sync inference task. + + Args: + user: The user who is creating the sync inference task. + model_endpoint_name: The name of the model endpoint for the task. + request: The body of the request to forward to the endpoint. + + Returns: + A response object that contains the status and result of the task. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + """ + + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + add_trace_model_name(model_endpoint_name) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if model_endpoint.record.endpoint_type not in [ + ModelEndpointType.SYNC, + ModelEndpointType.STREAMING, + ]: + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} does not serve sync requests." + ) + + inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + validated_request = validate_and_update_completion_params( + endpoint_content.inference_framework, request + ) + if not isinstance(validated_request, CompletionSyncV1Request): + raise ValueError( + f"request has type {validated_request.__class__.__name__}, expected type CompletionSyncV1Request" + ) + request = validated_request + + if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED: + args: Any = { + "prompts": [request.prompt], + "token_probs": True, + "generate_kwargs": { + "do_sample": True, + "temperature": request.temperature, + "max_new_tokens": request.max_new_tokens, + }, + "serialize_results_as_string": False, + } + if request.stop_sequences is not None: + # Deepspeed models only accepts one stop sequence + args["stop_sequence"] = request.stop_sequences[0] + + inference_request = SyncEndpointPredictV1Request( + args=args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + + if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + predict_result.result["result"][0], + model_endpoint, + request.prompt, + request.return_token_log_probs, + ), + ) + else: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + elif ( + endpoint_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + ): + tgi_args: Any = { + "inputs": request.prompt, + "parameters": { + "max_new_tokens": request.max_new_tokens, + "decoder_input_details": True, + }, + } + if request.stop_sequences is not None: + tgi_args["parameters"]["stop"] = request.stop_sequences + if request.temperature > 0: + tgi_args["parameters"]["temperature"] = request.temperature + tgi_args["parameters"]["do_sample"] = True + tgi_args["parameters"]["top_k"] = request.top_k + tgi_args["parameters"]["top_p"] = request.top_p + else: + tgi_args["parameters"]["do_sample"] = False + + inference_request = SyncEndpointPredictV1Request( + args=tgi_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + + output = json.loads(predict_result.result["result"]) + + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + output, + model_endpoint, + request.prompt, + request.return_token_log_probs, + ), + ) + elif endpoint_content.inference_framework == LLMInferenceFramework.VLLM: + vllm_args: Any = { + "prompt": request.prompt, + "max_tokens": request.max_new_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + } + if request.stop_sequences is not None: + vllm_args["stop"] = request.stop_sequences + vllm_args["temperature"] = request.temperature + if request.temperature > 0: + vllm_args["top_k"] = request.top_k + vllm_args["top_p"] = request.top_p + if request.return_token_log_probs: + vllm_args["logprobs"] = 1 + if request.include_stop_str_in_output is not None: + vllm_args["include_stop_str_in_output"] = request.include_stop_str_in_output + if request.guided_choice is not None: + vllm_args["guided_choice"] = request.guided_choice + if request.guided_regex is not None: + vllm_args["guided_regex"] = request.guided_regex + if request.guided_json is not None: + vllm_args["guided_json"] = request.guided_json + if request.guided_grammar is not None: + vllm_args["guided_grammar"] = request.guided_grammar + if request.skip_special_tokens is not None: + vllm_args["skip_special_tokens"] = request.skip_special_tokens + + inference_request = SyncEndpointPredictV1Request( + args=vllm_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + + output = json.loads(predict_result.result["result"]) + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + output, + model_endpoint, + request.prompt, + request.return_token_log_probs, + ), + ) + elif endpoint_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + lightllm_args: Any = { + "inputs": request.prompt, + "parameters": { + "max_new_tokens": request.max_new_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + }, + } + # TODO: implement stop sequences + if request.temperature > 0: + lightllm_args["parameters"]["temperature"] = request.temperature + lightllm_args["parameters"]["do_sample"] = True + lightllm_args["top_k"] = request.top_k + lightllm_args["top_p"] = request.top_p + else: + lightllm_args["parameters"]["do_sample"] = False + if request.return_token_log_probs: + lightllm_args["parameters"]["return_details"] = True + + inference_request = SyncEndpointPredictV1Request( + args=lightllm_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + + output = json.loads(predict_result.result["result"]) + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + output, + model_endpoint, + request.prompt, + request.return_token_log_probs, + ), + ) + elif endpoint_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + # TODO: Stop sequences is buggy and return token logprobs are not supported + # TODO: verify the implementation of presence_penalty and repetition_penalty + # and see if they fit our existing definition of presence_penalty and frequency_penalty + # Ref https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/sampling_penalty_kernels.cu + trt_llm_args: Any = { + "text_input": request.prompt, + "max_tokens": request.max_new_tokens, + "stop_words": request.stop_sequences if request.stop_sequences else "", + "bad_words": "", + "temperature": request.temperature, + } + + inference_request = SyncEndpointPredictV1Request( + args=trt_llm_args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + + output = json.loads(predict_result.result["result"]) + return CompletionSyncV1Response( + request_id=request_id, + output=self.model_output_to_completion_output( + output, + model_endpoint, + request.prompt, + request.return_token_log_probs, + ), + ) + else: + raise EndpointUnsupportedInferenceTypeException( + f"Unsupported inference framework {endpoint_content.inference_framework}" + ) + + +class CompletionStreamV1UseCase: + """ + Use case for running a stream prompt completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, user: User, model_endpoint_name: str, request: CompletionStreamV1Request + ) -> AsyncIterable[CompletionStreamV1Response]: + """ + Runs the use case to create a stream inference task. + NOTE: Must be called with await(), since the function is not a generator itself, but rather creates one and + returns a reference to it. This structure allows exceptions that occur before response streaming begins + to propagate to the client as HTTP exceptions with the appropriate code. + + Args: + user: The user who is creating the stream inference task. + model_endpoint_name: The name of the model endpoint for the task. + request: The body of the request to forward to the endpoint. + + Returns: + An asynchronous response chunk generator, containing response objects to be iterated through with 'async for'. + Each response object contains the status and result of the task. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectHasInvalidValueException: If there are multiple model endpoints with the given name. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + EndpointUnsupportedInferenceTypeException: If the model endpoint does not support streaming or uses + an unsupported inference framework. + UpstreamServiceError: If an error occurs upstream in the streaming inference API call. + InvalidRequestException: If request validation fails during inference. + """ + + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + add_trace_model_name(model_endpoint_name) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if model_endpoint.record.endpoint_type != ModelEndpointType.STREAMING: + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} is not a streaming endpoint." + ) + + inference_gateway = ( + self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() + ) + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + validated_request = validate_and_update_completion_params( + model_content.inference_framework, request + ) + if not isinstance(validated_request, CompletionStreamV1Request): + raise ValueError( + f"request has type {validated_request.__class__.__name__}, expected type CompletionStreamV1Request" + ) + request = validated_request + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + + args: Any = None + num_prompt_tokens = None + if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: + args = { + "prompts": [request.prompt], + "token_probs": True, + "generate_kwargs": { + "do_sample": True, + "temperature": request.temperature, + "max_new_tokens": request.max_new_tokens, + }, + "serialize_results_as_string": False, + } + if request.stop_sequences is not None: + # Deepspeed models only accepts one stop sequence + args["stop_sequence"] = request.stop_sequences[0] + num_prompt_tokens = count_tokens( + request.prompt, + model_content.model_name, + self.tokenizer_repository, + ) + elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: + args = { + "inputs": request.prompt, + "parameters": { + "max_new_tokens": request.max_new_tokens, + }, + } + if request.stop_sequences is not None: + args["parameters"]["stop"] = request.stop_sequences + if request.temperature > 0: + args["parameters"]["temperature"] = request.temperature + args["parameters"]["do_sample"] = True + args["parameters"]["top_k"] = request.top_k + args["parameters"]["top_p"] = request.top_p + else: + args["parameters"]["do_sample"] = False + num_prompt_tokens = count_tokens( + request.prompt, + model_content.model_name, + self.tokenizer_repository, + ) + elif model_content.inference_framework == LLMInferenceFramework.VLLM: + args = { + "prompt": request.prompt, + "max_tokens": request.max_new_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + } + if request.stop_sequences is not None: + args["stop"] = request.stop_sequences + args["temperature"] = request.temperature + if request.temperature > 0: + args["top_k"] = request.top_k + args["top_p"] = request.top_p + if request.return_token_log_probs: + args["logprobs"] = 1 + if request.include_stop_str_in_output is not None: + args["include_stop_str_in_output"] = request.include_stop_str_in_output + if request.guided_choice is not None: + args["guided_choice"] = request.guided_choice + if request.guided_regex is not None: + args["guided_regex"] = request.guided_regex + if request.guided_json is not None: + args["guided_json"] = request.guided_json + if request.guided_grammar is not None: + args["guided_grammar"] = request.guided_grammar + if request.skip_special_tokens is not None: + args["skip_special_tokens"] = request.skip_special_tokens + args["stream"] = True + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + args = { + "inputs": request.prompt, + "parameters": { + "max_new_tokens": request.max_new_tokens, + "presence_penalty": request.presence_penalty, + "frequency_penalty": request.frequency_penalty, + }, + } + # TODO: stop sequences + if request.temperature > 0: + args["parameters"]["temperature"] = request.temperature + args["parameters"]["do_sample"] = True + args["parameters"]["top_k"] = request.top_k + args["parameters"]["top_p"] = request.top_p + else: + args["parameters"]["do_sample"] = False + if request.return_token_log_probs: + args["parameters"]["return_details"] = True + num_prompt_tokens = count_tokens( + request.prompt, + model_content.model_name, + self.tokenizer_repository, + ) + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + # TODO: Stop sequences is buggy and return token logprobs are not supported + # TODO: verify the implementation of presence_penalty and repetition_penalty + # and see if they fit our existing definition of presence_penalty and frequency_penalty + # Ref https://github.com/NVIDIA/FasterTransformer/blob/main/src/fastertransformer/kernels/sampling_penalty_kernels.cu + args = { + "text_input": request.prompt, + "max_tokens": request.max_new_tokens, + "stop_words": request.stop_sequences if request.stop_sequences else "", + "bad_words": "", + "temperature": request.temperature, + "stream": True, + } + num_prompt_tokens = count_tokens( + request.prompt, + model_content.model_name, + self.tokenizer_repository, + ) + else: + raise EndpointUnsupportedInferenceTypeException( + f"Unsupported inference framework {model_content.inference_framework}" + ) + + inference_request = SyncEndpointPredictV1Request( + args=args, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + + return self._response_chunk_generator( + request=request, + request_id=request_id, + model_endpoint=model_endpoint, + model_content=model_content, + inference_gateway=inference_gateway, + inference_request=inference_request, + num_prompt_tokens=num_prompt_tokens, + manually_resolve_dns=manually_resolve_dns, + ) + + async def _response_chunk_generator( + self, + request: CompletionStreamV1Request, + request_id: Optional[str], + model_endpoint: ModelEndpoint, + model_content: GetLLMModelEndpointV1Response, + inference_gateway: StreamingModelEndpointInferenceGateway, + inference_request: SyncEndpointPredictV1Request, + num_prompt_tokens: Optional[int], + manually_resolve_dns: bool, + ) -> AsyncIterable[CompletionStreamV1Response]: + """ + Async generator yielding tokens to stream for the completions response. Should only be called when + returned directly by execute(). + """ + predict_result = inference_gateway.streaming_predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + + num_completion_tokens = 0 + async for res in predict_result: + if not res.status == TaskStatus.SUCCESS or res.result is None: + # Raise an UpstreamServiceError if the task has failed + if res.status == TaskStatus.FAILURE: + raise UpstreamServiceError( + status_code=500, + content=( + res.traceback.encode("utf-8") if res.traceback is not None else b"" + ), + ) + # Otherwise, yield empty response chunk for unsuccessful or empty results + yield CompletionStreamV1Response( + request_id=request_id, + output=None, + ) + else: + result = res.result + # DEEPSPEED + if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: + if "token" in result["result"]: + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["token"], + finished=False, + num_prompt_tokens=None, + num_completion_tokens=None, + ), + ) + else: + completion_token_count = len( + result["result"]["response"][0]["token_probs"]["tokens"] + ) + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["response"][0]["text"], + finished=True, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=completion_token_count, + ), + ) + # TEXT_GENERATION_INTERFACE + elif ( + model_content.inference_framework + == LLMInferenceFramework.TEXT_GENERATION_INFERENCE + ): + if result["result"].get("generated_text") is not None: + finished = True + else: + finished = False + + num_completion_tokens += 1 + + token = None + if request.return_token_log_probs: + token = TokenOutput( + token=result["result"]["token"]["text"], + log_prob=result["result"]["token"]["logprob"], + ) + try: + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["token"]["text"], + finished=finished, + num_prompt_tokens=(num_prompt_tokens if finished else None), + num_completion_tokens=num_completion_tokens, + token=token, + ), + ) + except Exception: + logger.exception( + f"Error parsing text-generation-inference output. Result: {result['result']}" + ) + if result["result"].get("error_type") == "validation": + raise InvalidRequestException( + result["result"].get("error") + ) # trigger a 400 + else: + raise UpstreamServiceError( + status_code=500, content=result.get("error") + ) # also change llms_v1.py that will return a 500 HTTPException so user can retry + # VLLM + elif model_content.inference_framework == LLMInferenceFramework.VLLM: + token = None + if request.return_token_log_probs: + token = TokenOutput( + token=result["result"]["text"], + log_prob=list(result["result"]["log_probs"].values())[0], + ) + finished = result["result"]["finished"] + num_prompt_tokens = result["result"]["count_prompt_tokens"] + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["text"], + finished=finished, + num_prompt_tokens=num_prompt_tokens if finished else None, + num_completion_tokens=result["result"]["count_output_tokens"], + token=token, + ), + ) + # LIGHTLLM + elif model_content.inference_framework == LLMInferenceFramework.LIGHTLLM: + token = None + num_completion_tokens += 1 + if request.return_token_log_probs: + token = TokenOutput( + token=result["result"]["token"]["text"], + log_prob=result["result"]["token"]["logprob"], + ) + finished = result["result"]["finished"] + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["token"]["text"], + finished=finished, + num_prompt_tokens=num_prompt_tokens if finished else None, + num_completion_tokens=num_completion_tokens, + token=token, + ), + ) + # TENSORRT_LLM + elif model_content.inference_framework == LLMInferenceFramework.TENSORRT_LLM: + num_completion_tokens += 1 + yield CompletionStreamV1Response( + request_id=request_id, + output=CompletionStreamOutput( + text=result["result"]["text_output"], + finished=False, # Tracked by https://github.com/NVIDIA/TensorRT-LLM/issues/240 + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ), + ) + # No else clause needed for an unsupported inference framework, since we check + # model_content.inference_framework in execute() prior to calling _response_chunk_generator, + # raising an exception if it is not one of the frameworks handled above. + + +def validate_endpoint_supports_openai_completion( + endpoint: ModelEndpoint, endpoint_content: GetLLMModelEndpointV1Response +): # pragma: no cover + if endpoint_content.inference_framework not in OPENAI_SUPPORTED_INFERENCE_FRAMEWORKS: + raise EndpointUnsupportedInferenceTypeException( + f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support openai compatible completion." + ) + + if ( + not isinstance(endpoint.record.current_model_bundle.flavor, RunnableImageLike) + or OPENAI_COMPLETION_PATH not in endpoint.record.current_model_bundle.flavor.extra_routes + ): + raise EndpointUnsupportedRequestException( + "Endpoint does not support v2 openai compatible completion" + ) + + +class CompletionSyncV2UseCase: + """ + Use case for running a v2 openai compatible completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): # pragma: no cover + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, user: User, model_endpoint_name: str, request: CompletionV2Request + ) -> CompletionV2SyncResponse: # pragma: no cover + """ + Runs the use case to create a sync inference task. + + Args: + user: The user who is creating the sync inference task. + model_endpoint_name: The name of the model endpoint for the task. + request: The body of the request to forward to the endpoint. + + Returns: + A response object that contains the status and result of the task. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + """ + + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + add_trace_model_name(model_endpoint_name) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if ( + model_endpoint.record.endpoint_type is not ModelEndpointType.SYNC + and model_endpoint.record.endpoint_type is not ModelEndpointType.STREAMING + ): + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} does not serve sync requests." + ) + + inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + + validate_endpoint_supports_openai_completion(model_endpoint, endpoint_content) + + # if inference framework is VLLM, we need to set the model to use the weights folder + if endpoint_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + + inference_request = SyncEndpointPredictV1Request( + args=request.model_dump(exclude_none=True), + destination_path=OPENAI_COMPLETION_PATH, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + try: + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + + output = json.loads(predict_result.result["result"]) + # reset model name to correct value + output["model"] = model_endpoint.record.name + return CompletionV2SyncResponse.model_validate(output) + except UpstreamServiceError as exc: + # Expect upstream inference service to handle bulk of input validation + if 400 <= exc.status_code < 500: + raise InvalidRequestException(exc.content) + raise exc + + +class CompletionStreamV2UseCase: + """ + Use case for running a v2 openai compatible completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): # pragma: no cover + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, model_endpoint_name: str, request: CompletionV2Request, user: User + ) -> AsyncGenerator[CompletionV2StreamSuccessChunk, None]: # pragma: no cover + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + add_trace_model_name(model_endpoint_name) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if model_endpoint.record.endpoint_type != ModelEndpointType.STREAMING: + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} is not a streaming endpoint." + ) + + inference_gateway = ( + self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() + ) + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + + validate_endpoint_supports_openai_completion(model_endpoint, model_content) + + # if inference framework is VLLM, we need to set the model to use the weights folder + if model_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + + inference_request = SyncEndpointPredictV1Request( + args=request.model_dump(exclude_none=True), + destination_path=OPENAI_COMPLETION_PATH, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + + return self._response_chunk_generator( + request_id=request_id, + model_endpoint=model_endpoint, + model_content=model_content, + inference_gateway=inference_gateway, + inference_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + ) + + async def _response_chunk_generator( + self, + request_id: Optional[str], + model_endpoint: ModelEndpoint, + model_content: GetLLMModelEndpointV1Response, + inference_gateway: StreamingModelEndpointInferenceGateway, + inference_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool, + ) -> AsyncGenerator[CompletionV2StreamSuccessChunk, None]: # pragma: no cover + """ + Async generator yielding tokens to stream for the completions response. Should only be called when + returned directly by execute(). + """ + try: + predict_result = inference_gateway.streaming_predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + except UpstreamServiceError as exc: + # Expect upstream inference service to handle bulk of input validation + if 400 <= exc.status_code < 500: + raise InvalidRequestException(str(exc)) + + raise exc + + async for res in predict_result: + if not res.status == TaskStatus.SUCCESS or res.result is None: + raise UpstreamServiceError( + status_code=500, + content=(res.traceback.encode("utf-8") if res.traceback is not None else b""), + ) + else: + result = res.result["result"] + # Reset model name to correct value + if "DONE" in result: + continue + result["model"] = model_endpoint.record.name + yield CompletionV2StreamSuccessChunk.model_validate(result) + + +def validate_endpoint_supports_chat_completion( + endpoint: ModelEndpoint, endpoint_content: GetLLMModelEndpointV1Response +): # pragma: no cover + if endpoint_content.inference_framework not in CHAT_SUPPORTED_INFERENCE_FRAMEWORKS: + raise EndpointUnsupportedInferenceTypeException( + f"The endpoint's inference framework ({endpoint_content.inference_framework}) does not support chat completion." + ) + + if ( + not isinstance(endpoint.record.current_model_bundle.flavor, RunnableImageLike) + or OPENAI_CHAT_COMPLETION_PATH + not in endpoint.record.current_model_bundle.flavor.extra_routes + ): + raise EndpointUnsupportedRequestException("Endpoint does not support chat completion") + + +class ChatCompletionSyncV2UseCase: + """ + Use case for running a chat completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, user: User, model_endpoint_name: str, request: ChatCompletionV2Request + ) -> ChatCompletionV2SyncResponse: # pragma: no cover + """ + Runs the use case to create a sync inference task. + + Args: + user: The user who is creating the sync inference task. + model_endpoint_name: The name of the model endpoint for the task. + request: The body of the request to forward to the endpoint. + + Returns: + A response object that contains the status and result of the task. + + Raises: + ObjectNotFoundException: If a model endpoint with the given name could not be found. + ObjectNotAuthorizedException: If the owner does not own the model endpoint. + """ + + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + add_trace_model_name(model_endpoint_name) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if ( + model_endpoint.record.endpoint_type is not ModelEndpointType.SYNC + and model_endpoint.record.endpoint_type is not ModelEndpointType.STREAMING + ): + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} does not serve sync requests." + ) + + inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + + validate_endpoint_supports_chat_completion(model_endpoint, endpoint_content) + + # if inference framework is VLLM, we need to set the model to use the weights folder + if endpoint_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + + inference_request = SyncEndpointPredictV1Request( + args=request.model_dump(exclude_none=True), + destination_path=OPENAI_CHAT_COMPLETION_PATH, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + try: + predict_result = await inference_gateway.predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + + if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: + raise UpstreamServiceError( + status_code=500, + content=( + predict_result.traceback.encode("utf-8") + if predict_result.traceback is not None + else b"" + ), + ) + + output = json.loads(predict_result.result["result"]) + # reset model name to correct value + output["model"] = model_endpoint.record.name + return ChatCompletionV2SyncResponse.model_validate(output) + except UpstreamServiceError as exc: + # Expect upstream inference service to handle bulk of input validation + if 400 <= exc.status_code < 500: + raise InvalidRequestException(exc.content) + raise exc + + +class ChatCompletionStreamV2UseCase: + """ + Use case for running a chat completion on an LLM endpoint. + """ + + def __init__( + self, + model_endpoint_service: ModelEndpointService, + llm_model_endpoint_service: LLMModelEndpointService, + tokenizer_repository: TokenizerRepository, + ): + self.model_endpoint_service = model_endpoint_service + self.llm_model_endpoint_service = llm_model_endpoint_service + self.authz_module = LiveAuthorizationModule() + self.tokenizer_repository = tokenizer_repository + + async def execute( + self, model_endpoint_name: str, request: ChatCompletionV2Request, user: User + ) -> AsyncGenerator[ChatCompletionV2StreamSuccessChunk, None]: # pragma: no cover + request_id = LoggerTagManager.get(LoggerTagKey.REQUEST_ID) + add_trace_request_id(request_id) + + model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( + owner=user.team_id, name=model_endpoint_name, order_by=None + ) + + if len(model_endpoints) == 0: + raise ObjectNotFoundException(f"Model endpoint {model_endpoint_name} not found.") + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" + ) + + add_trace_model_name(model_endpoint_name) + + model_endpoint = model_endpoints[0] + + if not self.authz_module.check_access_read_owned_entity( + user, model_endpoint.record + ) and not self.authz_module.check_endpoint_public_inference_for_user( + user, model_endpoint.record + ): + raise ObjectNotAuthorizedException + + if model_endpoint.record.endpoint_type != ModelEndpointType.STREAMING: + raise EndpointUnsupportedInferenceTypeException( + f"Endpoint {model_endpoint_name} is not a streaming endpoint." + ) + + inference_gateway = ( + self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() + ) + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint.record.id + ) + + model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) + + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) + validate_endpoint_supports_chat_completion(model_endpoint, model_content) + + # if inference framework is VLLM, we need to set the model to use the weights folder + if model_content.inference_framework == LLMInferenceFramework.VLLM: + request.model = VLLM_MODEL_WEIGHTS_FOLDER + + inference_request = SyncEndpointPredictV1Request( + args=request.model_dump(exclude_none=True), + destination_path=OPENAI_CHAT_COMPLETION_PATH, + num_retries=NUM_DOWNSTREAM_REQUEST_RETRIES, + timeout_seconds=DOWNSTREAM_REQUEST_TIMEOUT_SECONDS, + ) + + return self._response_chunk_generator( + request_id=request_id, + model_endpoint=model_endpoint, + model_content=model_content, + inference_gateway=inference_gateway, + inference_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + ) + + async def _response_chunk_generator( + self, + request_id: Optional[str], + model_endpoint: ModelEndpoint, + model_content: GetLLMModelEndpointV1Response, + inference_gateway: StreamingModelEndpointInferenceGateway, + inference_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool, + ) -> AsyncGenerator[ChatCompletionV2StreamSuccessChunk, None]: + """ + Async generator yielding tokens to stream for the completions response. Should only be called when + returned directly by execute(). + """ + try: + predict_result = inference_gateway.streaming_predict( + topic=model_endpoint.record.destination, + predict_request=inference_request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, + ) + except UpstreamServiceError as exc: + # Expect upstream inference service to handle bulk of input validation + if 400 <= exc.status_code < 500: + raise InvalidRequestException(str(exc)) + + raise exc + + async for res in predict_result: + if not res.status == TaskStatus.SUCCESS or res.result is None: + raise UpstreamServiceError( + status_code=500, + content=(res.traceback.encode("utf-8") if res.traceback is not None else b""), + ) + else: + result = res.result["result"] + # Reset model name to correct value + if "DONE" in result: + continue + result["model"] = model_endpoint.record.name + yield ChatCompletionV2StreamSuccessChunk.model_validate(result) + + +class ModelDownloadV1UseCase: + def __init__( + self, + filesystem_gateway: FilesystemGateway, + model_endpoint_service: ModelEndpointService, + llm_artifact_gateway: LLMArtifactGateway, + ): + self.filesystem_gateway = filesystem_gateway + self.model_endpoint_service = model_endpoint_service + self.llm_artifact_gateway = llm_artifact_gateway + + async def execute(self, user: User, request: ModelDownloadRequest) -> ModelDownloadResponse: + model_endpoints = await self.model_endpoint_service.list_model_endpoints( + owner=user.team_id, name=request.model_name, order_by=None + ) + if len(model_endpoints) == 0: + raise ObjectNotFoundException + + if len(model_endpoints) > 1: + raise ObjectHasInvalidValueException( + f"Expected 1 LLM model endpoint for model name {request.model_name}, got {len(model_endpoints)}" + ) + model_files = self.llm_artifact_gateway.get_model_weights_urls( + user.team_id, request.model_name + ) + urls = {} + for model_file in model_files: + # don't want to make s3 bucket full keys public, so trim to just keep file name + public_file_name = model_file.rsplit("/", 1)[-1] + urls[public_file_name] = self.filesystem_gateway.generate_signed_url(model_file) + return ModelDownloadResponse(urls=urls) + + +async def _fill_hardware_info( + llm_artifact_gateway: LLMArtifactGateway, request: CreateLLMModelEndpointV1Request +): + if ( + request.gpus is None + or request.gpu_type is None + or request.cpus is None + or request.memory is None + or request.storage is None + or request.nodes_per_worker is None + ): + if not ( + request.gpus is None + and request.gpu_type is None + and request.cpus is None + and request.memory is None + and request.storage is None + and request.nodes_per_worker is None + ): + raise ObjectHasInvalidValueException( + "All hardware spec fields (gpus, gpu_type, cpus, memory, storage, nodes_per_worker) must be provided if any hardware spec field is missing." + ) + checkpoint_path = get_checkpoint_path(request.model_name, request.checkpoint_path) + hardware_info = await _infer_hardware( + llm_artifact_gateway, request.model_name, checkpoint_path + ) + request.gpus = hardware_info.gpus + request.gpu_type = hardware_info.gpu_type + request.cpus = hardware_info.cpus + request.memory = hardware_info.memory + request.storage = hardware_info.storage + request.nodes_per_worker = hardware_info.nodes_per_worker + if hardware_info.gpus: # make lint happy + request.num_shards = hardware_info.gpus + + +def get_model_param_count_b(model_name: str) -> int: + """Get the number of parameters in the model in billions""" + if "mixtral-8x7b" in model_name: + model_param_count_b = 47 + elif "mixtral-8x22b" in model_name: + model_param_count_b = 140 + elif "phi-3-mini" in model_name: + model_param_count_b = 4 + elif "phi-3-small" in model_name: + model_param_count_b = 8 + elif "phi-3-medium" in model_name: + model_param_count_b = 15 + elif "deepseek-coder-v2-lite" in model_name: + model_param_count_b = 16 + elif "deepseek-coder-v2" in model_name: + model_param_count_b = 237 + else: + numbers = re.findall(r"(\d+)b", model_name) + if len(numbers) == 0: + raise ObjectHasInvalidValueException( + f"Unable to infer number of parameters for {model_name}." + ) + model_param_count_b = int(numbers[-1]) + return model_param_count_b + + +@lru_cache() +async def _infer_hardware( + llm_artifact_gateway: LLMArtifactGateway, + model_name: str, + checkpoint_path: str, + is_batch_job: bool = False, + max_context_length: Optional[int] = None, +) -> CreateDockerImageBatchJobResourceRequests: + config = llm_artifact_gateway.get_model_config(checkpoint_path) + + dtype_size = 2 + kv_multiplier = 20 if is_batch_job else 2 + + max_position_embeddings = ( + min(max_context_length, config["max_position_embeddings"]) + if max_context_length + else config["max_position_embeddings"] + ) + + min_kv_cache_size = ( + kv_multiplier + * dtype_size + * config["num_hidden_layers"] + * config["hidden_size"] + * max_position_embeddings + // (config["num_attention_heads"] // config["num_key_value_heads"]) + ) + + model_param_count_b = get_model_param_count_b(model_name) + model_weights_size = dtype_size * model_param_count_b * 1_000_000_000 + + min_memory_gb = math.ceil((min_kv_cache_size + model_weights_size) / 1_000_000_000 / 0.9) + + logger.info( + f"Memory calculation result: {min_memory_gb=} for {model_name} context_size: {max_position_embeddings}, min_kv_cache_size: {min_kv_cache_size}, model_weights_size: {model_weights_size}, is_batch_job: {is_batch_job}" + ) + + config_map = await _get_recommended_hardware_config_map() + by_model_name = {item["name"]: item for item in yaml.safe_load(config_map["byModelName"])} + by_gpu_memory_gb = yaml.safe_load(config_map["byGpuMemoryGb"]) + if model_name in by_model_name: + cpus = by_model_name[model_name]["cpus"] + gpus = by_model_name[model_name]["gpus"] + memory = by_model_name[model_name]["memory"] + storage = by_model_name[model_name]["storage"] + gpu_type = by_model_name[model_name]["gpu_type"] + nodes_per_worker = by_model_name[model_name]["nodes_per_worker"] + else: + by_gpu_memory_gb = sorted(by_gpu_memory_gb, key=lambda x: x["gpu_memory_le"]) + for recs in by_gpu_memory_gb: + if min_memory_gb <= recs["gpu_memory_le"]: + cpus = recs["cpus"] + gpus = recs["gpus"] + memory = recs["memory"] + storage = recs["storage"] + gpu_type = recs["gpu_type"] + nodes_per_worker = recs["nodes_per_worker"] + break + else: + raise ObjectHasInvalidValueException(f"Unable to infer hardware for {model_name}.") + + return CreateDockerImageBatchJobResourceRequests( + cpus=cpus, + gpus=gpus, + memory=memory, + storage=storage, + gpu_type=gpu_type, + nodes_per_worker=nodes_per_worker, + ) + + +def infer_addition_engine_args_from_model_name( + model_name: str, +) -> VLLMEndpointAdditionalArgs: + # Increase max gpu utilization for larger models + gpu_memory_utilization = 0.9 + try: + model_param_count_b = get_model_param_count_b(model_name) + if model_param_count_b >= 70: + gpu_memory_utilization = 0.95 + except ObjectHasInvalidValueException: # pragma: no cover + pass + + # Gemma 2 requires flashinfer attention backend + attention_backend = None + if model_name.startswith("gemma-2"): + attention_backend = "FLASHINFER" + + trust_remote_code = None + # DeepSeek requires trust_remote_code + if model_name.startswith("deepseek"): + trust_remote_code = True + + return VLLMEndpointAdditionalArgs( + gpu_memory_utilization=gpu_memory_utilization, + attention_backend=attention_backend, + trust_remote_code=trust_remote_code, + ) + + +class CreateBatchCompletionsUseCase: + def __init__( + self, + docker_image_batch_job_gateway: DockerImageBatchJobGateway, + docker_repository: DockerRepository, + docker_image_batch_job_bundle_repo: DockerImageBatchJobBundleRepository, + llm_artifact_gateway: LLMArtifactGateway, + ): + self.docker_image_batch_job_gateway = docker_image_batch_job_gateway + self.docker_repository = docker_repository + self.docker_image_batch_job_bundle_repo = docker_image_batch_job_bundle_repo + self.llm_artifact_gateway = llm_artifact_gateway + + async def create_batch_job_bundle( + self, + user: User, + request: CreateBatchCompletionsEngineRequest, + hardware: CreateDockerImageBatchJobResourceRequests, + ) -> DockerImageBatchJobBundle: + assert hardware.gpu_type is not None + + bundle_name = ( + f"{request.model_cfg.model}_{datetime.datetime.utcnow().strftime('%y%m%d-%H%M%S')}" + ) + + image_tag = await _get_latest_batch_tag(LLMInferenceFramework.VLLM) + + config_file_path = "/opt/config.json" + + batch_bundle = ( + await self.docker_image_batch_job_bundle_repo.create_docker_image_batch_job_bundle( + name=bundle_name, + created_by=user.user_id, + owner=user.team_id, + image_repository=hmi_config.batch_inference_vllm_repository, + image_tag=image_tag, + command=[ + "dumb-init", + "--", + "/bin/bash", + "-c", + "ddtrace-run python vllm_batch.py", + ], + env={"CONFIG_FILE": config_file_path}, + mount_location=config_file_path, + cpus=str(hardware.cpus), + memory=str(hardware.memory), + storage=str(hardware.storage), + gpus=hardware.gpus, + gpu_type=hardware.gpu_type, + public=False, + ) + ) + return batch_bundle + + async def execute( + self, user: User, request: CreateBatchCompletionsV1Request + ) -> CreateBatchCompletionsV1Response: + if ( + request.data_parallelism is not None and request.data_parallelism > 1 + ): # pragma: no cover + raise ObjectHasInvalidValueException( + "Data parallelism is disabled for batch completions." + ) + + request.model_cfg.checkpoint_path = get_checkpoint_path( + request.model_cfg.model, request.model_cfg.checkpoint_path + ) + hardware = await _infer_hardware( + self.llm_artifact_gateway, + request.model_cfg.model, + request.model_cfg.checkpoint_path, + is_batch_job=True, + max_context_length=request.model_cfg.max_context_length, + ) + assert hardware.gpus is not None + + engine_request = CreateBatchCompletionsEngineRequest.from_api_v1(request) + engine_request.model_cfg.num_shards = hardware.gpus + if engine_request.tool_config and engine_request.tool_config.name != "code_evaluator": + raise ObjectHasInvalidValueException( + "Only code_evaluator tool is supported for batch completions." + ) + + additional_engine_args = infer_addition_engine_args_from_model_name( + engine_request.model_cfg.model + ) + + engine_request.max_gpu_memory_utilization = additional_engine_args.gpu_memory_utilization + engine_request.attention_backend = additional_engine_args.attention_backend + + batch_bundle = await self.create_batch_job_bundle(user, engine_request, hardware) + + validate_resource_requests( + bundle=batch_bundle, + cpus=hardware.cpus, + memory=hardware.memory, + storage=hardware.storage, + gpus=hardware.gpus, + gpu_type=hardware.gpu_type, + ) + + if ( + engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1 + ): # pragma: no cover + raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") + + job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( + created_by=user.user_id, + owner=user.team_id, + job_config=engine_request.model_dump(by_alias=True), + env=batch_bundle.env, + command=batch_bundle.command, + repo=batch_bundle.image_repository, + tag=batch_bundle.image_tag, + resource_requests=hardware, + labels=engine_request.labels, + mount_location=batch_bundle.mount_location, + override_job_max_runtime_s=engine_request.max_runtime_sec, + num_workers=engine_request.data_parallelism, + ) + return CreateBatchCompletionsV1Response(job_id=job_id) + + +class CreateBatchCompletionsV2UseCase: + def __init__( + self, + llm_batch_completions_service: LLMBatchCompletionsService, + llm_artifact_gateway: LLMArtifactGateway, + ): + self.llm_batch_completions_service = llm_batch_completions_service + self.llm_artifact_gateway = llm_artifact_gateway + + async def execute( + self, request: CreateBatchCompletionsV2Request, user: User + ) -> CreateBatchCompletionsV2Response: + request.model_cfg.checkpoint_path = get_checkpoint_path( + request.model_cfg.model, request.model_cfg.checkpoint_path + ) + + if ( + request.cpus is not None + and request.gpus is not None + and request.memory is not None + and request.storage is not None + and request.gpu_type is not None + ): + hardware = CreateDockerImageBatchJobResourceRequests( + cpus=request.cpus, + gpus=request.gpus, + memory=request.memory, + storage=request.storage, + gpu_type=request.gpu_type, + ) + else: + if ( + request.cpus is not None + or request.gpus is not None + or request.memory is not None + or request.storage is not None + or request.gpu_type is not None + ): + logger.warning( + "All hardware spec fields (cpus, gpus, memory, storage, gpu_type) must be provided if any hardware spec field is provided. Will attempt to infer hardware spec from checkpoint." + ) + + hardware = await _infer_hardware( + self.llm_artifact_gateway, + request.model_cfg.model, + request.model_cfg.checkpoint_path, + is_batch_job=True, + max_context_length=request.model_cfg.max_context_length, + ) + + engine_request = CreateBatchCompletionsEngineRequest.from_api_v2(request) + engine_request.model_cfg.num_shards = hardware.gpus + + validate_resource_requests( + bundle=None, + cpus=hardware.cpus, + memory=hardware.memory, + storage=hardware.storage, + gpus=hardware.gpus, + gpu_type=hardware.gpu_type, + ) + + if engine_request.max_runtime_sec is None or engine_request.max_runtime_sec < 1: + raise ObjectHasInvalidValueException("max_runtime_sec must be a positive integer.") + + # Right now we only support VLLM for batch inference. Refactor this if we support more inference frameworks. + image_repo = hmi_config.batch_inference_vllm_repository + image_tag = await _get_latest_batch_v2_tag(LLMInferenceFramework.VLLM) + + additional_engine_args = infer_addition_engine_args_from_model_name( + engine_request.model_cfg.model + ) + + # Overwrite model config fields with those determined by additional engine args + for field in VLLMModelConfig.model_fields.keys(): + config_value = getattr(additional_engine_args, field, None) + if config_value is not None and hasattr(engine_request.model_cfg, field): + setattr(engine_request.model_cfg, field, config_value) + + engine_request.attention_backend = additional_engine_args.attention_backend + + return await self.llm_batch_completions_service.create_batch_job( + user=user, + job_request=engine_request, + image_repo=image_repo, + image_tag=image_tag, + resource_requests=hardware, + labels=engine_request.labels, + max_runtime_sec=engine_request.max_runtime_sec, + num_workers=engine_request.data_parallelism, + ) + + +class GetBatchCompletionV2UseCase: + def __init__(self, llm_batch_completions_service: LLMBatchCompletionsService): + self.llm_batch_completions_service = llm_batch_completions_service + + async def execute( + self, + batch_completion_id: str, + user: User, + ) -> GetBatchCompletionV2Response: + job = await self.llm_batch_completions_service.get_batch_job( + batch_completion_id, + user=user, + ) + + if not job: + raise ObjectNotFoundException(f"Batch completion {batch_completion_id} not found.") + + return GetBatchCompletionV2Response(job=job) + + +class UpdateBatchCompletionV2UseCase: + def __init__(self, llm_batch_completions_service: LLMBatchCompletionsService): + self.llm_batch_completions_service = llm_batch_completions_service + + async def execute( + self, + batch_completion_id: str, + request: UpdateBatchCompletionsV2Request, + user: User, + ) -> UpdateBatchCompletionsV2Response: + result = await self.llm_batch_completions_service.update_batch_job( + batch_completion_id, + user=user, + request=request, + ) + if not result: + raise ObjectNotFoundException(f"Batch completion {batch_completion_id} not found.") + + return UpdateBatchCompletionsV2Response( + **result.model_dump(by_alias=True, exclude_none=True), + success=True, + ) + + +class CancelBatchCompletionV2UseCase: + def __init__(self, llm_batch_completions_service: LLMBatchCompletionsService): + self.llm_batch_completions_service = llm_batch_completions_service + + async def execute( + self, + batch_completion_id: str, + user: User, + ) -> CancelBatchCompletionsV2Response: + return CancelBatchCompletionsV2Response( + success=await self.llm_batch_completions_service.cancel_batch_job( + batch_completion_id, + user=user, + ) + ) diff --git a/server/llm_engine_server/domain/use_cases/model_bundle_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py similarity index 94% rename from server/llm_engine_server/domain/use_cases/model_bundle_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py index aa69bbde..d79b8793 100644 --- a/server/llm_engine_server/domain/use_cases/model_bundle_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_bundle_use_cases.py @@ -1,7 +1,7 @@ from typing import Optional, Union from uuid import uuid4 -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV1Request, CloneModelBundleV2Request, CreateModelBundleV1Request, @@ -15,16 +15,11 @@ ModelBundleV1Response, ModelBundleV2Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( - DockerImageNotFoundException, - ObjectNotAuthorizedException, - ObjectNotFoundException, +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( ArtifactLike, CloudpickleArtifactFlavor, CustomFramework, @@ -37,8 +32,13 @@ TensorflowFramework, ZipArtifactFlavor, ) -from llm_engine_server.domain.gateways import ModelPrimitiveGateway -from llm_engine_server.domain.repositories import DockerRepository, ModelBundleRepository +from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) +from model_engine_server.domain.gateways import ModelPrimitiveGateway +from model_engine_server.domain.repositories import DockerRepository, ModelBundleRepository class CreateModelBundleV1UseCase: @@ -52,7 +52,7 @@ def __init__( docker_repository: DockerRepository, model_primitive_gateway: ModelPrimitiveGateway, ): - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() self.model_bundle_repository = model_bundle_repository self.docker_repository = docker_repository self.model_primitive_gateway = model_primitive_gateway @@ -145,7 +145,7 @@ async def execute( load_predict_fn_module_path=metadata.get("load_predict_fn_module_path", ""), load_model_fn_module_path=metadata.get("load_model_fn_module_path", ""), ) - else: # request.packaging_type == ModelBundlePackagingType.CLOUDPICKLE: + else: # request.packaging_type == ModelBundlePackagingType.LIRA: flavor = RunnableImageFlavor( flavor=ModelBundleFlavorType.RUNNABLE_IMAGE, repository="", # stub value, not used @@ -182,7 +182,7 @@ class CloneModelBundleV1UseCase: def __init__(self, model_bundle_repository: ModelBundleRepository): self.model_bundle_repository = model_bundle_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, @@ -280,7 +280,7 @@ class GetModelBundleByIdV1UseCase: def __init__(self, model_bundle_repository: ModelBundleRepository): self.model_bundle_repository = model_bundle_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_bundle_id: str) -> ModelBundleV1Response: """ @@ -346,13 +346,16 @@ def __init__( docker_repository: DockerRepository, model_primitive_gateway: ModelPrimitiveGateway, ): - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() self.model_bundle_repository = model_bundle_repository self.docker_repository = docker_repository self.model_primitive_gateway = model_primitive_gateway async def execute( - self, user: User, request: CreateModelBundleV2Request + self, + user: User, + request: CreateModelBundleV2Request, + do_auth_check: bool = True, ) -> CreateModelBundleV2Response: """ Runs the use case to create a Model Bundle. @@ -360,6 +363,9 @@ async def execute( Args: user: The user who is creating the Model Bundle. request: A request object that contains the creation fields. + do_auth_check: Whether we should run the auth check. We're skipping the check + inside of the llm endpoint creation use case. This is fine as long as that use case + isn't directly exposed to the outside. Returns: A response object that contains the creation response fields. @@ -396,7 +402,7 @@ async def execute( tag=request.flavor.tag, ) - if not self.authz_module.check_access_create_bundle_v2(user, request): + if do_auth_check and not self.authz_module.check_access_create_bundle_v2(user, request): raise ObjectNotAuthorizedException created_by = user.user_id @@ -428,14 +434,14 @@ async def execute( ) app_config = request.flavor.app_config else: - location = "unused" # Nonempty to support legacy LLMEngine + location = "unused" # Nonempty to support legacy Launch requirements = [] env_params = { "framework_type": ModelBundleFrameworkType.CUSTOM, "ecr_repo": request.flavor.repository, "image_tag": request.flavor.tag, } - packaging_type = ModelBundlePackagingType.CLOUDPICKLE + packaging_type = ModelBundlePackagingType.LIRA app_config = None model_bundle = await self.model_bundle_repository.create_model_bundle( @@ -464,7 +470,7 @@ class CloneModelBundleV2UseCase: def __init__(self, model_bundle_repository: ModelBundleRepository): self.model_bundle_repository = model_bundle_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, @@ -561,7 +567,7 @@ class GetModelBundleByIdV2UseCase: def __init__(self, model_bundle_repository: ModelBundleRepository): self.model_bundle_repository = model_bundle_repository - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_bundle_id: str) -> ModelBundleV2Response: """ diff --git a/server/llm_engine_server/domain/use_cases/model_endpoint_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py similarity index 68% rename from server/llm_engine_server/domain/use_cases/model_endpoint_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py index 4cf04c16..69beac00 100644 --- a/server/llm_engine_server/domain/use_cases/model_endpoint_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoint_use_cases.py @@ -4,10 +4,12 @@ Read model endpoint creation logs: GET model-endpoints//creation-logs """ -from typing import List, Optional +import re +from dataclasses import dataclass +from typing import Any, Dict, List, Optional -from llm_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.common.constants import SUPPORTED_POST_INFERENCE_HOOKS +from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, DeleteModelEndpointV1Response, @@ -17,33 +19,39 @@ UpdateModelEndpointV1Request, UpdateModelEndpointV1Response, ) -from llm_engine_server.common.resource_limits import MAX_ENDPOINT_SIZE, validate_resource_requests -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, +from model_engine_server.common.resource_limits import MAX_ENDPOINT_SIZE, validate_resource_requests +from model_engine_server.common.settings import REQUIRED_ENDPOINT_LABELS, RESTRICTED_ENDPOINT_LABELS +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( + ModelBundle, ModelEndpoint, ModelEndpointType, + RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor, ) -from llm_engine_server.domain.exceptions import ( +from model_engine_server.domain.exceptions import ( + EndpointBillingTagsMalformedException, EndpointInfraStateNotFound, EndpointLabelsException, EndpointResourceInvalidRequestException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + PostInferenceHooksException, ) -from llm_engine_server.domain.repositories import ModelBundleRepository -from llm_engine_server.domain.services import ModelEndpointService +from model_engine_server.domain.repositories import ModelBundleRepository +from model_engine_server.domain.services import ModelEndpointService CONVERTED_FROM_ARTIFACT_LIKE_KEY = "_CONVERTED_FROM_ARTIFACT_LIKE" +MODEL_BUNDLE_CHANGED_KEY = "_MODEL_BUNDLE_CHANGED" + +DEFAULT_DISALLOWED_TEAMS = ["_INVALID_TEAM"] -logger = make_logger(filename_wo_ext(__name__)) +logger = make_logger(logger_name()) def model_endpoint_entity_to_get_model_endpoint_response( @@ -96,10 +104,13 @@ def validate_deployment_resources( min_workers: Optional[int], max_workers: Optional[int], endpoint_type: ModelEndpointType, + can_scale_http_endpoint_from_zero: bool, ) -> None: # TODO: we should be also validating the update request against the existing state in k8s (e.g. # so min_workers <= max_workers always) maybe this occurs already in update_model_endpoint. - min_endpoint_size = 0 if endpoint_type == ModelEndpointType.ASYNC else 1 + min_endpoint_size = ( + 0 if endpoint_type == ModelEndpointType.ASYNC or can_scale_http_endpoint_from_zero else 1 + ) if min_workers is not None and min_workers < min_endpoint_size: raise EndpointResourceInvalidRequestException( f"Requested min workers {min_workers} too low" @@ -110,6 +121,86 @@ def validate_deployment_resources( ) +@dataclass +class ValidationResult: + passed: bool + message: str + + +# Placeholder team and product label validator that only checks for a single invalid team +def simple_team_product_validator(team: str, product: str) -> ValidationResult: + if team in DEFAULT_DISALLOWED_TEAMS: + return ValidationResult(False, "Invalid team") + else: + return ValidationResult(True, "Valid team") + + +def validate_labels(labels: Dict[str, str]) -> None: + for required_label in REQUIRED_ENDPOINT_LABELS: + if required_label not in labels: + raise EndpointLabelsException( + f"Missing label '{required_label}' in labels. These are all required: {REQUIRED_ENDPOINT_LABELS}", + ) + + for restricted_label in RESTRICTED_ENDPOINT_LABELS: + if restricted_label in labels: + raise EndpointLabelsException(f"Cannot specify '{restricted_label}' in labels") + + # TODO: remove after we fully migrate to the new team + product validator + try: + from plugins.known_users import ALLOWED_TEAMS + + # Make sure that the team is one of the values from a canonical set. + if labels["team"] not in ALLOWED_TEAMS: + raise EndpointLabelsException(f"Invalid team label, must be one of: {ALLOWED_TEAMS}") + except ModuleNotFoundError: + pass + + try: + from shared_plugins.team_product_label_validation import validate_team_product_label + except ModuleNotFoundError: + validate_team_product_label = simple_team_product_validator + + validation_result = validate_team_product_label(labels["team"], labels["product"]) + if not validation_result.passed: + raise EndpointLabelsException(validation_result.message) + + # Check k8s will accept the label values + regex_pattern = "(([A-Za-z0-9][-A-Za-z0-9_.]*)?[A-Za-z0-9])?" # k8s label regex + for label_value in labels.values(): + if re.fullmatch(regex_pattern, label_value) is None: + raise EndpointLabelsException( + f"Invalid label value {label_value}, must match regex {regex_pattern}" + ) + + +def validate_billing_tags(billing_tags: Optional[Dict[str, Any]]) -> None: + if billing_tags is None: + return + + if type(billing_tags) is not dict: + raise EndpointBillingTagsMalformedException("Billing tags must be a json dictionary") + + required_keys = { + "idempotencyKeyPrefix", + "product", + "type", + "subType", + "payee", + "payor", + "reference", + } + + missing_keys = required_keys - set(billing_tags) + if len(missing_keys) > 0: + raise EndpointBillingTagsMalformedException(f"Missing billing tag keys {missing_keys}") + for k, v in billing_tags.items(): + if type(k) is not str or type(v) not in [str, dict]: + raise EndpointBillingTagsMalformedException( + "Billing tags must have string keys and string/dict values" + ) + + def validate_post_inference_hooks(user: User, post_inference_hooks: Optional[List[str]]) -> None: # We're going to ask for user-specified auth for callbacks instead of providing default auth # from Launch. Otherwise, we'd want to prevent non-privileged users from using the @@ -118,10 +209,51 @@ def validate_post_inference_hooks(user: User, post_inference_hooks: Optional[Lis return for hook in post_inference_hooks: - if hook not in [ - CALLBACK_POST_INFERENCE_HOOK, - ]: - raise ValueError(f"Unsupported post-inference hook {hook}") + if hook not in SUPPORTED_POST_INFERENCE_HOOKS: + raise PostInferenceHooksException( + f"Unsupported post-inference hook {hook}. The supported hooks are: {SUPPORTED_POST_INFERENCE_HOOKS}" + ) + + +def validate_bundle_multinode_compatibility(bundle: ModelBundle, nodes_per_worker: int): + """ + Only some bundles can be multinode compatible. + """ + if nodes_per_worker == 1: + return + # can type ignore, bundle.flavor is a RunnableImageFlavor/StreamingEnhancedRunnableImageFlavor thus it has worker_command and worker_env + if ( + type(bundle.flavor) in {RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor} + and bundle.flavor.worker_command is not None # type: ignore + and bundle.flavor.worker_env is not None # type: ignore + ): + return + raise ObjectHasInvalidValueException( + f"Bundle {bundle.name} is not multinode compatible. It must be a RunnableImage and have worker_command and worker_args set." + ) + + +def validate_endpoint_resource_multinode_compatibility( + gpu_type: Optional[str], + gpus: Optional[int], + endpoint_type: ModelEndpointType, + nodes_per_worker: int, +): + """ + Only gpu streaming endpoints can be multinode compatible. + """ + if nodes_per_worker == 1: + return + if ( + endpoint_type == ModelEndpointType.STREAMING + and gpu_type is not None + and gpus is not None + and gpus > 0 + ): + return + raise ObjectHasInvalidValueException( + "Endpoint is not multinode compatible. Only streaming GPU endpoints can be multinode compatible." + ) class CreateModelEndpointV1UseCase: @@ -132,7 +264,7 @@ def __init__( ): self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, request: CreateModelEndpointV1Request @@ -141,15 +273,23 @@ async def execute( min_workers=request.min_workers, max_workers=request.max_workers, endpoint_type=request.endpoint_type, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), ) if request.labels is None: raise EndpointLabelsException("Endpoint labels cannot be None!") + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) validate_post_inference_hooks(user, request.post_inference_hooks) bundle = await self.model_bundle_repository.get_model_bundle( model_bundle_id=request.model_bundle_id ) + if bundle is None: raise ObjectNotFoundException + validate_bundle_multinode_compatibility(bundle, request.nodes_per_worker) + validate_endpoint_resource_multinode_compatibility( + request.gpu_type, request.gpus, request.endpoint_type, request.nodes_per_worker + ) if not self.authz_module.check_access_read_owned_entity(user, bundle): raise ObjectNotAuthorizedException if not isinstance(bundle.flavor, StreamingEnhancedRunnableImageFlavor) and ( @@ -207,6 +347,7 @@ async def execute( memory=request.memory, gpu_type=request.gpu_type, storage=request.storage, + nodes_per_worker=request.nodes_per_worker, optimize_costs=bool(request.optimize_costs), min_workers=request.min_workers, max_workers=request.max_workers, @@ -240,11 +381,14 @@ def __init__( ): self.model_bundle_repository = model_bundle_repository self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( self, user: User, model_endpoint_id: str, request: UpdateModelEndpointV1Request ) -> UpdateModelEndpointV1Response: + if request.labels is not None: + validate_labels(request.labels) + validate_billing_tags(request.billing_tags) validate_post_inference_hooks(user, request.post_inference_hooks) endpoint = await self.model_endpoint_service.get_model_endpoint( @@ -290,19 +434,29 @@ async def execute( # infra_state to make sure that after the update, all resources are valid and in sync. # E.g. If user only want to update gpus and leave gpu_type as None, we use the existing gpu_type # from infra_state to avoid passing in None to validate_resource_requests. + raw_request = request.dict(exclude_unset=True) validate_resource_requests( bundle=bundle, - cpus=request.cpus or infra_state.resource_state.cpus, - memory=request.memory or infra_state.resource_state.memory, - storage=request.storage or infra_state.resource_state.storage, - gpus=request.gpus or infra_state.resource_state.gpus, - gpu_type=request.gpu_type or infra_state.resource_state.gpu_type, + cpus=(request.cpus if "cpus" in raw_request else infra_state.resource_state.cpus), + memory=( + request.memory if "memory" in raw_request else infra_state.resource_state.memory + ), + storage=( + request.storage if "storage" in raw_request else infra_state.resource_state.storage + ), + gpus=(request.gpus if "gpus" in raw_request else infra_state.resource_state.gpus), + gpu_type=( + request.gpu_type + if "gpu_type" in raw_request + else infra_state.resource_state.gpu_type + ), ) validate_deployment_resources( min_workers=request.min_workers, max_workers=request.max_workers, endpoint_type=endpoint_record.endpoint_type, + can_scale_http_endpoint_from_zero=self.model_endpoint_service.can_scale_http_endpoint_from_zero(), ) if request.metadata is not None and CONVERTED_FROM_ARTIFACT_LIKE_KEY in request.metadata: @@ -381,7 +535,7 @@ class GetModelEndpointByIdV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_endpoint_id: str) -> GetModelEndpointV1Response: """ @@ -409,7 +563,7 @@ async def execute(self, user: User, model_endpoint_id: str) -> GetModelEndpointV class DeleteModelEndpointByIdV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User, model_endpoint_id: str) -> DeleteModelEndpointV1Response: model_endpoint = await self.model_endpoint_service.get_model_endpoint_record( diff --git a/server/llm_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py b/model-engine/model_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py similarity index 64% rename from server/llm_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py index 5ebe873c..c35ce456 100644 --- a/server/llm_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/model_endpoints_schema_use_cases.py @@ -1,9 +1,9 @@ -from llm_engine_server.common.dtos.model_endpoints import GetModelEndpointsSchemaV1Response -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, +from model_engine_server.common.dtos.model_endpoints import GetModelEndpointsSchemaV1Response +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, ) -from llm_engine_server.domain.services import ModelEndpointService +from model_engine_server.domain.services import ModelEndpointService class GetModelEndpointsSchemaV1UseCase: @@ -13,7 +13,7 @@ class GetModelEndpointsSchemaV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute(self, user: User) -> GetModelEndpointsSchemaV1Response: """Execute the use case. diff --git a/server/llm_engine_server/domain/use_cases/streaming_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py similarity index 50% rename from server/llm_engine_server/domain/use_cases/streaming_inference_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py index 1eddf85a..baf68a06 100644 --- a/server/llm_engine_server/domain/use_cases/streaming_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/streaming_inference_use_cases.py @@ -1,20 +1,21 @@ from typing import AsyncIterable -from llm_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, +) +from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.entities import ModelEndpointType -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.services.model_endpoint_service import ModelEndpointService +from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService class CreateStreamingInferenceTaskV1UseCase: @@ -24,10 +25,10 @@ class CreateStreamingInferenceTaskV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( - self, user: User, model_endpoint_id: str, request: EndpointPredictV1Request + self, user: User, model_endpoint_id: str, request: SyncEndpointPredictV1Request ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs the use case to create a sync inference task. @@ -61,6 +62,24 @@ async def execute( inference_gateway = ( self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() ) + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint_id + ) + # Hack: manually resolve dns if istio is present. Since we do not inject istio for multinode, + # empirically we find that without manual dns resolution, requests to the k8s service DNS name fail, + # likely because the requests are getting changed by Istio. A fix is to resolve the service DNS name + # (e.g. model-endpoint-foo.namespace.svc.cluster.local) to the actual IP address of the service + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) return inference_gateway.streaming_predict( - topic=model_endpoint.record.destination, predict_request=request + topic=model_endpoint.record.destination, + predict_request=request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) diff --git a/server/llm_engine_server/domain/use_cases/sync_inference_use_cases.py b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py similarity index 51% rename from server/llm_engine_server/domain/use_cases/sync_inference_use_cases.py rename to model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py index 44fc9c11..a665a846 100644 --- a/server/llm_engine_server/domain/use_cases/sync_inference_use_cases.py +++ b/model-engine/model_engine_server/domain/use_cases/sync_inference_use_cases.py @@ -1,18 +1,19 @@ -from llm_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, +) +from model_engine_server.domain.entities import ModelEndpointType +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.entities import ModelEndpointType -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.services.model_endpoint_service import ModelEndpointService +from model_engine_server.domain.services.model_endpoint_service import ModelEndpointService class CreateSyncInferenceTaskV1UseCase: @@ -22,10 +23,10 @@ class CreateSyncInferenceTaskV1UseCase: def __init__(self, model_endpoint_service: ModelEndpointService): self.model_endpoint_service = model_endpoint_service - self.authz_module = ScaleAuthorizationModule() + self.authz_module = LiveAuthorizationModule() async def execute( - self, user: User, model_endpoint_id: str, request: EndpointPredictV1Request + self, user: User, model_endpoint_id: str, request: SyncEndpointPredictV1Request ) -> SyncEndpointPredictV1Response: """ Runs the use case to create a sync inference task. @@ -41,6 +42,7 @@ async def execute( Raises: ObjectNotFoundException: If a model endpoint with the given ID could not be found. ObjectNotAuthorizedException: If the owner does not own the model endpoint. + asyncio.exceptions.TimeoutError: If the task times out. """ model_endpoint = await self.model_endpoint_service.get_model_endpoint( model_endpoint_id=model_endpoint_id @@ -64,6 +66,24 @@ async def execute( ) inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() + autoscaling_metrics_gateway = ( + self.model_endpoint_service.get_inference_autoscaling_metrics_gateway() + ) + await autoscaling_metrics_gateway.emit_inference_autoscaling_metric( + endpoint_id=model_endpoint_id + ) + # Hack: manually resolve dns if istio is present. Since we do not inject istio for multinode, + # empirically we find that without manual dns resolution, requests to the k8s service DNS name fail, + # likely because the requests are getting changed by Istio. A fix is to resolve the service DNS name + # (e.g. model-endpoint-foo.namespace.svc.cluster.local) to the actual IP address of the service + manually_resolve_dns = ( + model_endpoint.infra_state is not None + and model_endpoint.infra_state.resource_state.nodes_per_worker > 1 + and hmi_config.istio_enabled + ) return await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=request + topic=model_endpoint.record.destination, + predict_request=request, + manually_resolve_dns=manually_resolve_dns, + endpoint_name=model_endpoint.record.name, ) diff --git a/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py b/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py new file mode 100644 index 00000000..a0bd1769 --- /dev/null +++ b/model-engine/model_engine_server/domain/use_cases/trigger_use_cases.py @@ -0,0 +1,244 @@ +import os + +from croniter import croniter +from model_engine_server.common.dtos.triggers import ( + CreateTriggerV1Request, + CreateTriggerV1Response, + DeleteTriggerV1Response, + GetTriggerV1Response, + ListTriggersV1Response, + UpdateTriggerV1Request, + UpdateTriggerV1Response, +) +from model_engine_server.common.resource_limits import validate_resource_requests +from model_engine_server.common.settings import REQUIRED_ENDPOINT_LABELS +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.core.config import infra_config +from model_engine_server.domain.authorization.live_authorization_module import ( + LiveAuthorizationModule, +) +from model_engine_server.domain.exceptions import ( + CronSyntaxException, + DockerImageNotFoundException, + EndpointLabelsException, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, +) +from model_engine_server.domain.gateways.cron_job_gateway import CronJobGateway +from model_engine_server.domain.repositories import ( + DockerImageBatchJobBundleRepository, + DockerRepository, + TriggerRepository, +) +from model_engine_server.domain.use_cases.model_endpoint_use_cases import validate_labels + +DEFAULT_HOST = f"https://model-engine.{infra_config().dns_host_domain}" + +ALLOWED_CRON_MACROS = set( + [ + "@yearly", + "@annually", + "@monthly", + "@weekly", + "@daily", + "@midnight", + "@hourly", + ] +) + + +def validate_cron(cron: str) -> None: + if len(cron) == 0: + raise CronSyntaxException("Cron expression cannot be empty.") + + if cron not in ALLOWED_CRON_MACROS: + # case on presence of macro identifier + if cron[0] == "@": + raise CronSyntaxException( + f"Unsupported macro supplied: '{cron}'. Please select from the following list, {ALLOWED_CRON_MACROS}." + ) + elif not croniter.is_valid(cron): + raise CronSyntaxException( + f"Invalid Cron syntax: '{cron}'. Please see https://crontab.guru." + ) + + +class CreateTriggerUseCase: + """Use case for creating a Trigger""" + + def __init__( + self, + trigger_repository: TriggerRepository, + cron_job_gateway: CronJobGateway, + docker_image_batch_job_bundle_repository: DockerImageBatchJobBundleRepository, + docker_repository: DockerRepository, + ): + self.trigger_repository = trigger_repository + self.cron_job_gateway = cron_job_gateway + self.docker_image_batch_job_bundle_repository = docker_image_batch_job_bundle_repository + self.docker_repository = docker_repository + self.authz_module = LiveAuthorizationModule() + + async def execute( + self, + user: User, + request: CreateTriggerV1Request, + ) -> CreateTriggerV1Response: + batch_bundle = ( + await self.docker_image_batch_job_bundle_repository.get_docker_image_batch_job_bundle( + request.bundle_id + ) + ) + + if batch_bundle is None: + raise ObjectNotFoundException("The specified batch job bundle could not be found") + if not self.authz_module.check_access_read_owned_entity(user, batch_bundle): + raise ObjectNotAuthorizedException( + f"User {user} does not have permission for the specified batch job bundle" + ) + + if not self.docker_repository.image_exists( + image_tag=batch_bundle.image_tag, repository_name=batch_bundle.image_repository + ): + raise DockerImageNotFoundException( + repository=batch_bundle.image_repository, + tag=batch_bundle.image_tag, + ) # Error if docker image could not be found either + + # check if required resources exist + if None in [batch_bundle.cpus, batch_bundle.memory]: + raise ObjectHasInvalidValueException("Bundle must specify value for cpus and memory") + # validate resource request in cluster also + validate_resource_requests( + bundle=batch_bundle, + cpus=batch_bundle.cpus, + memory=batch_bundle.memory, + storage=batch_bundle.storage, + gpus=batch_bundle.gpus, + gpu_type=batch_bundle.gpu_type, + ) + + if request.default_job_metadata is None: + raise EndpointLabelsException( + f"Missing labels in default_job_metadata. These are all required: {REQUIRED_ENDPOINT_LABELS}" + ) + + validate_labels(request.default_job_metadata) + validate_cron(request.cron_schedule) + + trigger = await self.trigger_repository.create_trigger( + name=request.name, + created_by=user.user_id, + owner=user.team_id, + cron_schedule=request.cron_schedule, + docker_image_batch_job_bundle_id=request.bundle_id, + default_job_config=request.default_job_config, + default_job_metadata=request.default_job_metadata, + ) + + request.default_job_metadata["trigger_id"] = trigger.id + await self.cron_job_gateway.create_cronjob( + request_host=os.getenv("GATEWAY_URL") or DEFAULT_HOST, + trigger_id=trigger.id, + created_by=user.user_id, + owner=user.team_id, + cron_schedule=request.cron_schedule, + docker_image_batch_job_bundle_id=request.bundle_id, + default_job_config=request.default_job_config, + default_job_metadata=request.default_job_metadata, + ) + + return CreateTriggerV1Response(trigger_id=trigger.id) + + +class ListTriggersUseCase: + def __init__(self, trigger_repository: TriggerRepository): + self.trigger_repository = trigger_repository + + async def execute(self, user: User) -> ListTriggersV1Response: + triggers = await self.trigger_repository.list_triggers(owner=user.team_id) + return ListTriggersV1Response( + triggers=[GetTriggerV1Response.from_orm(trigger) for trigger in triggers] + ) + + +class GetTriggerUseCase: + def __init__(self, trigger_repository: TriggerRepository): + self.trigger_repository = trigger_repository + self.authz_module = LiveAuthorizationModule() + + async def execute(self, user: User, trigger_id: str) -> GetTriggerV1Response: + trigger = await self.trigger_repository.get_trigger(trigger_id=trigger_id) + if trigger is None: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity(user, trigger): + raise ObjectNotAuthorizedException( + f"User {user} is not authorized for trigger {trigger_id}" + ) + + return GetTriggerV1Response.from_orm(trigger) + + +class UpdateTriggerUseCase: + def __init__( + self, + trigger_repository: TriggerRepository, + cron_job_gateway: CronJobGateway, + ): + self.trigger_repository = trigger_repository + self.cron_job_gateway = cron_job_gateway + self.authz_module = LiveAuthorizationModule() + + async def execute( + self, user: User, trigger_id: str, request: UpdateTriggerV1Request + ) -> UpdateTriggerV1Response: + trigger = await self.trigger_repository.get_trigger(trigger_id=trigger_id) + if trigger is None: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity(user, trigger): + raise ObjectNotAuthorizedException( + f"User {user} is not authorized for trigger {trigger_id}" + ) + + success = True + if request.cron_schedule is not None: + validate_cron(request.cron_schedule) + success = await self.trigger_repository.update_trigger( + trigger_id=trigger_id, cron_schedule=request.cron_schedule + ) + + if success: + await self.cron_job_gateway.update_cronjob( + trigger_id=trigger.id, + cron_schedule=request.cron_schedule, + suspend=request.suspend, + ) + + return UpdateTriggerV1Response(success=success) + + +class DeleteTriggerUseCase: + def __init__( + self, + trigger_repository: TriggerRepository, + cron_job_gateway: CronJobGateway, + ): + self.trigger_repository = trigger_repository + self.cron_job_gateway = cron_job_gateway + self.authz_module = LiveAuthorizationModule() + + async def execute(self, user: User, trigger_id: str) -> DeleteTriggerV1Response: + trigger = await self.trigger_repository.get_trigger(trigger_id=trigger_id) + if trigger is None: + raise ObjectNotFoundException + if not self.authz_module.check_access_read_owned_entity(user, trigger): + raise ObjectNotAuthorizedException( + f"User {user} is not authorized for trigger {trigger_id}" + ) + + success = await self.trigger_repository.delete_trigger(trigger_id=trigger_id) + if success: + await self.cron_job_gateway.delete_cronjob(trigger_id=trigger_id) + + return DeleteTriggerV1Response(success=success) diff --git a/server/llm_engine_server/inference/forwarding/__init__.py b/model-engine/model_engine_server/entrypoints/__init__.py similarity index 100% rename from server/llm_engine_server/inference/forwarding/__init__.py rename to model-engine/model_engine_server/entrypoints/__init__.py diff --git a/server/llm_engine_server/entrypoints/init_database.py b/model-engine/model_engine_server/entrypoints/init_database.py similarity index 66% rename from server/llm_engine_server/entrypoints/init_database.py rename to model-engine/model_engine_server/entrypoints/init_database.py index cea6330f..14f7ac77 100644 --- a/server/llm_engine_server/entrypoints/init_database.py +++ b/model-engine/model_engine_server/entrypoints/init_database.py @@ -2,13 +2,13 @@ import os import psycopg2 -from llm_engine_server.db.base import Base -from llm_engine_server.db.models import * +from model_engine_server.db.base import Base, get_engine_url +from model_engine_server.db.models import * from sqlalchemy import create_engine from sqlalchemy.engine import Engine from tenacity import Retrying, stop_after_attempt, wait_exponential -SCHEMAS = ["llm_engine", "model"] +SCHEMAS = ["hosted_model_inference", "model"] def init_database(database_url: str, psycopg_connection): @@ -38,13 +38,16 @@ def init_database_and_engine(database_url) -> Engine: if __name__ == "__main__": url = os.getenv("ML_INFRA_DATABASE_URL") - if url is not None: - for attempt in Retrying( - stop=stop_after_attempt(6), - wait=wait_exponential(), - reraise=True, - ): - with attempt: - init_database_and_engine(url) - - print(f"Successfully initialized database at {url}") + # If we are at this point, we want to init the db. + if url is None: + print("No k8s secret for DB url found, trying AWS secret") + url = get_engine_url(read_only=False, sync=True).url + for attempt in Retrying( + stop=stop_after_attempt(6), + wait=wait_exponential(), + reraise=True, + ): + with attempt: + init_database_and_engine(url) + + print(f"Successfully initialized database at {url}") diff --git a/server/llm_engine_server/entrypoints/init_llm_engine_models.py b/model-engine/model_engine_server/entrypoints/init_spellbook_models.py similarity index 95% rename from server/llm_engine_server/entrypoints/init_llm_engine_models.py rename to model-engine/model_engine_server/entrypoints/init_spellbook_models.py index 80ffc275..46c6a53a 100644 --- a/server/llm_engine_server/entrypoints/init_llm_engine_models.py +++ b/model-engine/model_engine_server/entrypoints/init_spellbook_models.py @@ -2,7 +2,7 @@ from typing import Any, Dict import requests -from llm_engine_server.domain.entities import ModelEndpointType +from launch.api_client.model.model_endpoint_type import ModelEndpointType from tenacity import retry, stop_after_attempt, wait_fixed DEFAULT_NETWORK_TIMEOUT_SEC = 10 @@ -120,7 +120,7 @@ def spellbook_bundle_payload( "flavor": { "flavor": "runnable_image", "repository": "instant-llm", - "tag": f"llm_engine_llm_cuda_image_{git_commit}", + "tag": f"launch_llm_cuda_image_{git_commit}", "command": [ "dumb-init", "--", @@ -147,7 +147,7 @@ def spellbook_endpoint_payload( *, endpoint_name: str, bundle_name: str, - endpoint_type: ModelEndpointType = ModelEndpointType.SYNC, + endpoint_type: ModelEndpointType = "async", min_workers: int = 0, max_workers: int = 1, memory: str = "185Gi", @@ -228,7 +228,7 @@ def create_model_endpoint( return response.json() -def create_llm_engine_deployments(gateway_url: str): +def create_spellbook_deployments(gateway_url: str): for model_name, service_config in SERVICE_CONFIGS.items(): bundle_payload = spellbook_bundle_payload( model_name=model_name, @@ -252,5 +252,4 @@ def create_llm_engine_deployments(gateway_url: str): args = parser.parse_args() ensure_gateway_ready(args.gateway_url) - # TODO: Renable this when we're ready to pre-init models - # create_llm_engine_deployments(args.gateway_url) + create_spellbook_deployments(args.gateway_url) diff --git a/server/llm_engine_server/entrypoints/k8s_cache.py b/model-engine/model_engine_server/entrypoints/k8s_cache.py similarity index 56% rename from server/llm_engine_server/entrypoints/k8s_cache.py rename to model-engine/model_engine_server/entrypoints/k8s_cache.py index 593418c5..98dcd9b3 100644 --- a/server/llm_engine_server/entrypoints/k8s_cache.py +++ b/model-engine/model_engine_server/entrypoints/k8s_cache.py @@ -10,48 +10,56 @@ from kubernetes import config as kube_config from kubernetes.config.config_exception import ConfigException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.constants import READYZ_FPATH -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.db.base import SessionAsyncNullPool -from llm_engine_server.domain.repositories import DockerRepository -from llm_engine_server.infra.gateways import FakeMonitoringMetricsGateway -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.api.dependencies import get_monitoring_metrics_gateway +from model_engine_server.common.config import hmi_config +from model_engine_server.common.constants import READYZ_FPATH +from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.db.base import get_session_async_null_pool +from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( - FakeSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( + FakeQueueEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.image_cache_gateway import ImageCacheGateway -from llm_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.resources.image_cache_gateway import ImageCacheGateway +from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, +from model_engine_server.infra.repositories import ( + ACRDockerRepository, + ECRDockerRepository, + FakeDockerRepository, ) -from llm_engine_server.infra.repositories import ECRDockerRepository -from llm_engine_server.infra.repositories.db_model_endpoint_record_repository import ( +from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, ) -from llm_engine_server.infra.repositories.model_endpoint_cache_repository import ( +from model_engine_server.infra.repositories.model_endpoint_cache_repository import ( ModelEndpointCacheRepository, ) -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) -from llm_engine_server.infra.repositories.redis_model_endpoint_cache_repository import ( +from model_engine_server.infra.repositories.redis_model_endpoint_cache_repository import ( RedisModelEndpointCacheRepository, ) -from llm_engine_server.infra.services.image_cache_service import ImageCacheService -from llm_engine_server.infra.services.model_endpoint_cache_service import ( +from model_engine_server.infra.services.image_cache_service import ImageCacheService +from model_engine_server.infra.services.model_endpoint_cache_service import ( ModelEndpointCacheWriteService, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) # This is the entrypoint to the k8s cacher try: @@ -91,23 +99,35 @@ async def main(args: Any): logger.info(f"Using cache redis url {redis_url}") cache_repo = RedisModelEndpointCacheRepository(redis_info=redis_url) - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + monitoring_metrics_gateway = get_monitoring_metrics_gateway() endpoint_record_repo = DbModelEndpointRecordRepository( monitoring_metrics_gateway=monitoring_metrics_gateway, - session=SessionAsyncNullPool, + session=get_session_async_null_pool(), read_only=True, ) - sqs_delegate: SQSEndpointResourceDelegate + + queue_delegate: QueueEndpointResourceDelegate if CIRCLECI: - sqs_delegate = FakeSQSEndpointResourceDelegate() + queue_delegate = FakeQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "azure": + queue_delegate = ASBQueueEndpointResourceDelegate() else: - sqs_delegate = LiveSQSEndpointResourceDelegate( + queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) - k8s_resource_manager = LiveEndpointResourceGateway(sqs_delegate=sqs_delegate) + k8s_resource_manager = LiveEndpointResourceGateway( + queue_delegate=queue_delegate, + inference_autoscaling_metrics_gateway=None, + ) image_cache_gateway = ImageCacheGateway() - docker_repo = ECRDockerRepository() + docker_repo: DockerRepository + if CIRCLECI: + docker_repo = FakeDockerRepository() + elif infra_config().docker_repo_prefix.endswith("azurecr.io"): + docker_repo = ACRDockerRepository() + else: + docker_repo = ECRDockerRepository() while True: loop_start = time.time() await loop_iteration( diff --git a/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py b/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py new file mode 100644 index 00000000..2e0caf29 --- /dev/null +++ b/model-engine/model_engine_server/entrypoints/populate_llm_fine_tuning_job_repository.py @@ -0,0 +1,449 @@ +""" +This script initializes the file backing the LLMFineTuneRepository and adds a test template to it + +FOR TESTING: +To get the bundle id, print the result of calling +`get_or_create_docker_image_batch_job_bundle(CREATE_FINE_TUNE_DI_BATCH_JOB_BUNDLE_REQUEST, users[0])` +from e2e_test_v1.py + +FOR ACTUAL CREATION: +You will need a docker image from the fine-tuning repo. Refer to llm/finetune_pipeline/README.md for instructions. + +""" + +import argparse +import asyncio + +import requests +from model_engine_server.common.config import hmi_config +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.repositories import ( + ABSFileLLMFineTuneRepository, + S3FileLLMFineTuneRepository, +) + +FT_IMAGE_TAG = "00f0edae308d9cd5d9fc24fbd4ee0180e8edc738" + +BUNDLE_NAME_BY_MODEL = { + "7b_or_13b": "fine-tune-upload-safetensors", + "llama_2_34b": "fine-tune-upload-safetensors-34b", + "llama_2_70b": "fine-tune-upload-safetensors-70b", +} + +DEFAULT_7B_MODEL_CONFIG = { + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "num_shards": 1, + "quantize": None, + "cpus": 8, + "memory": "24Gi", + "storage": "40Gi", + "gpus": 1, + "gpu_type": "nvidia-ampere-a10", + "min_workers": 0, + "max_workers": 1, + "per_worker": 10, + "endpoint_type": "streaming", +} + +DEFAULT_13B_MODEL_CONFIG = { + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "num_shards": 2, + "quantize": None, + "cpus": 16, + "memory": "48Gi", + "storage": "80Gi", + "gpus": 2, + "gpu_type": "nvidia-ampere-a10", + "min_workers": 0, + "max_workers": 1, + "per_worker": 10, + "endpoint_type": "streaming", +} + +# DEFAULT_34B_MODEL_CONFIG defined below because it depends on cloud_provider + +DEFAULT_70B_MODEL_CONFIG = { + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "num_shards": 2, + "quantize": None, + "cpus": 20, + "memory": "160Gi", + "storage": "200Gi", + "gpus": 2, + "gpu_type": "nvidia-ampere-a100e", + "min_workers": 0, + "max_workers": 1, + "per_worker": 30, + "endpoint_type": "streaming", +} + + +def create_model_bundle(cloud_provider, url, user, model_type, image_tag): + RESOURCE_REQUESTS_BY_MODEL = { + "7b_or_13b": { + "cpus": 40, + "memory": "160Gi", + "storage": "94Gi", + "gpus": 2 if cloud_provider == "azure" else 4, + "gpu_type": "nvidia-ampere-a10", + }, + "llama_2_34b": { + "cpus": 60, + "memory": "400Gi", + "storage": "300Gi", + "gpus": 4, + "gpu_type": "nvidia-ampere-a100e", + }, + "llama_2_70b": { + "cpus": 80, + "memory": "1000Gi", + "storage": "500Gi", + "gpus": 8, + "gpu_type": "nvidia-ampere-a100e", + }, + } + + name = BUNDLE_NAME_BY_MODEL[model_type] + resource_requests = RESOURCE_REQUESTS_BY_MODEL[model_type] + + response = requests.post( + f"{url}/v1/docker-image-batch-job-bundles", + json={ + "name": name, + "image_repository": "spellbook-finetune", + "image_tag": image_tag, + "command": [ + "dumb-init", + "--", + "ddtrace-run", + "python", + "llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py", + "--config-file", + "/launch_reserved/config_file.json", + ], + "mount_location": "/launch_reserved/config_file.json", + "resource_requests": resource_requests, + "public": True, + }, + headers={"Content-Type": "application/json"}, + auth=requests.auth.HTTPBasicAuth(user, ""), + ).json() + return response["docker_image_batch_job_bundle_id"] + + +async def main(args): + cloud_provider = args.cloud_provider + url = args.url or f"http://model-engine.{hmi_config.gateway_namespace}.svc.cluster.local" + repository = args.repository or hmi_config.cloud_file_llm_fine_tune_repository + user = args.user or "test-user" + initialize_repository = args.initialize_repository + + if repository.startswith("s3://"): + repo = S3FileLLMFineTuneRepository(file_path=repository) + elif repository.startswith("azure://") or "blob.core.windows.net" in repository: + repo = ABSFileLLMFineTuneRepository(file_path=repository) + else: + raise ValueError(f"LLM fine-tune repository must be S3 or ABS file; got {repository}") + + # Clears the file. Needed the first time we're populating data + if initialize_repository: + await repo.initialize_data() + + lora_7b_or_13b_bun = create_model_bundle(cloud_provider, url, user, "7b_or_13b", FT_IMAGE_TAG) + print(f"lora_7b_or_13b bundle id: {lora_7b_or_13b_bun}") + + lora_llama_2_34b_bun = create_model_bundle( + cloud_provider, url, user, "llama_2_34b", FT_IMAGE_TAG + ) + print(f"lora_34b_bun bundle id: {lora_llama_2_34b_bun}") + + lora_llama_2_70b_bun = create_model_bundle( + cloud_provider, url, user, "llama_2_70b", FT_IMAGE_TAG + ) + print(f"llama_2_70b bundle id: {lora_llama_2_70b_bun}") + + await repo.write_job_template_for_model( + "mpt-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "mosaicml/mpt-7b", + "_BASE_MODEL_SHORT": "mpt-7b", + }, + required_params=[], + ), + ) + print("Wrote mpt-7b with lora") + + await repo.write_job_template_for_model( + "mpt-7b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "mosaicml/mpt-7b-instruct", + "_BASE_MODEL_SHORT": "mpt-7b-instruct", + }, + required_params=[], + ), + ) + print("Wrote mpt-7b-instruct with lora") + + await repo.write_job_template_for_model( + "llama-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-7b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-7b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-7b with lora") + + await repo.write_job_template_for_model( + "llama-2-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-7b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-7b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-2-7b with lora") + + await repo.write_job_template_for_model( + "llama-2-7b-chat", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-7b-chat", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-7b-chat", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-2-7b-chat with lora") + + await repo.write_job_template_for_model( + "llama-2-13b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_13B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-13b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-13b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-2-13b with lora") + + await repo.write_job_template_for_model( + "llama-2-13b-chat", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_13B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-13b-chat", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-13b-chat", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote llama-2-13b-chat with lora") + + await repo.write_job_template_for_model( + "llama-2-70b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_llama_2_70b_bun, + launch_endpoint_config=DEFAULT_70B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "hf-llama-2-70b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "llama-2-70b", # == create llm endpoint request's model_name + "max_length": 1024, # To prevent OOM on 8xA100e + }, + required_params=[], + ), + ) + print("Wrote llama-2-70b with lora") + + await repo.write_job_template_for_model( + "mistral-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "mistralai/mistral-7b-v0.1", # == model_name inside of training script + "_BASE_MODEL_SHORT": "mistral-7b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote mistral-7b with lora") + + await repo.write_job_template_for_model( + "mistral-7b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "mistralai/mistral-7b-instruct-v0.1", # == model_name inside of training script + "_BASE_MODEL_SHORT": "mistral-7b-instruct", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote mistral-7b-instruct with lora") + await repo.write_job_template_for_model( + "codellama-7b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-7b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-7b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-7b with lora") + + await repo.write_job_template_for_model( + "codellama-7b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_7B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-7b-instruct", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-7b-instruct", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-7b-instruct with lora") + + await repo.write_job_template_for_model( + "codellama-13b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_13B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-13b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-13b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-13b with lora") + + await repo.write_job_template_for_model( + "codellama-13b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_7b_or_13b_bun, + launch_endpoint_config=DEFAULT_13B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-13b-instruct", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-13b-instruct", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-13b-instruct with lora") + + DEFAULT_34B_MODEL_CONFIG = { + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "latest", + "num_shards": 2 if cloud_provider == "azure" else 4, + "quantize": None, + "cpus": 32, + "memory": "80Gi", + "storage": "100Gi", + "gpus": 2 if cloud_provider == "azure" else 4, + "gpu_type": "nvidia-ampere-a10", + "min_workers": 0, + "max_workers": 1, + "per_worker": 10, + "endpoint_type": "streaming", + } + + await repo.write_job_template_for_model( + "codellama-34b", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_llama_2_34b_bun, + launch_endpoint_config=DEFAULT_34B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-34b", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-34b", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-34b with lora") + + await repo.write_job_template_for_model( + "codellama-34b-instruct", + "lora", + LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=lora_llama_2_34b_bun, + launch_endpoint_config=DEFAULT_34B_MODEL_CONFIG, + default_hparams={ + "_BASE_MODEL": "codellama-34b-instruct", # == model_name inside of training script + "_BASE_MODEL_SHORT": "codellama-34b-instruct", # == create llm endpoint request's model_name + }, + required_params=[], + ), + ) + print("Wrote codellama-34b-instruct with lora") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process command line arguments.") + parser.add_argument( + "--cloud-provider", + choices=["aws", "azure"], + help="Cloud provider", + required=False, + default="aws", + ) + parser.add_argument("--url", help="Url to the model-engine gateway", required=False) + parser.add_argument( + "--repository", help="Url to the LLM fine-tuning job repository", required=False + ) + parser.add_argument( + "--user", help="User ID to create Docker image batch job bundles with", required=False + ) + parser.add_argument( + "--initialize-repository", action="store_true", required=False, default=False + ) + args = parser.parse_args() + asyncio.run(main(args)) diff --git a/server/llm_engine_server/entrypoints/start_batch_job_orchestration.py b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py similarity index 55% rename from server/llm_engine_server/entrypoints/start_batch_job_orchestration.py rename to model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py index 33863048..de1bd59b 100644 --- a/server/llm_engine_server/entrypoints/start_batch_job_orchestration.py +++ b/model-engine/model_engine_server/entrypoints/start_batch_job_orchestration.py @@ -4,40 +4,48 @@ from datetime import timedelta import aioredis -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.db.base import SessionAsyncNullPool -from llm_engine_server.domain.entities import BatchJobSerializationFormat -from llm_engine_server.infra.gateways import ( +from model_engine_server.api.dependencies import get_monitoring_metrics_gateway +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.config import infra_config +from model_engine_server.db.base import get_session_async_null_pool +from model_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.gateways import TaskQueueGateway +from model_engine_server.infra.gateways import ( + ABSFilesystemGateway, + ASBInferenceAutoscalingMetricsGateway, CeleryTaskQueueGateway, - FakeMonitoringMetricsGateway, LiveAsyncModelEndpointInferenceGateway, LiveBatchJobProgressGateway, LiveModelEndpointInfraGateway, LiveModelEndpointsSchemaGateway, LiveStreamingModelEndpointInferenceGateway, LiveSyncModelEndpointInferenceGateway, + RedisInferenceAutoscalingMetricsGateway, S3FilesystemGateway, ) -from llm_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( - FakeSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( + FakeQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( LiveEndpointResourceGateway, ) -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, ) -from llm_engine_server.infra.repositories import ( +from model_engine_server.infra.repositories import ( DbBatchJobRecordRepository, DbModelEndpointRecordRepository, RedisModelEndpointCacheRepository, ) -from llm_engine_server.infra.services import ( +from model_engine_server.infra.services import ( LiveBatchJobOrchestrationService, LiveModelEndpointService, ) @@ -50,44 +58,69 @@ async def run_batch_job( serialization_format: BatchJobSerializationFormat, timeout_seconds: float, ): - session = SessionAsyncNullPool + session = get_session_async_null_pool() pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) redis = aioredis.Redis(connection_pool=pool) - redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS) sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() + servicebus_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SERVICEBUS) + + monitoring_metrics_gateway = get_monitoring_metrics_gateway() model_endpoint_record_repo = DbModelEndpointRecordRepository( - monitoring_metrics_gateway=monitoring_metrics_gateway, - session=session, - read_only=False, + monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False ) - sqs_delegate: SQSEndpointResourceDelegate + queue_delegate: QueueEndpointResourceDelegate if CIRCLECI: - sqs_delegate = FakeSQSEndpointResourceDelegate() + queue_delegate = FakeQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "azure": + queue_delegate = ASBQueueEndpointResourceDelegate() else: - sqs_delegate = LiveSQSEndpointResourceDelegate( + queue_delegate = SQSQueueEndpointResourceDelegate( sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) ) - resource_gateway = LiveEndpointResourceGateway(sqs_delegate=sqs_delegate) - model_endpoint_cache_repo = RedisModelEndpointCacheRepository( - redis_client=redis, + inference_autoscaling_metrics_gateway = ( + ASBInferenceAutoscalingMetricsGateway() + if infra_config().cloud_provider == "azure" + else RedisInferenceAutoscalingMetricsGateway(redis_client=redis) + ) + resource_gateway = LiveEndpointResourceGateway( + queue_delegate=queue_delegate, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, ) + + inference_task_queue_gateway: TaskQueueGateway + infra_task_queue_gateway: TaskQueueGateway + if infra_config().cloud_provider == "azure": + inference_task_queue_gateway = servicebus_task_queue_gateway + infra_task_queue_gateway = servicebus_task_queue_gateway + else: + inference_task_queue_gateway = sqs_task_queue_gateway + infra_task_queue_gateway = sqs_task_queue_gateway + model_endpoint_infra_gateway = LiveModelEndpointInfraGateway( resource_gateway=resource_gateway, - task_queue_gateway=redis_task_queue_gateway, + task_queue_gateway=infra_task_queue_gateway, + ) + model_endpoint_cache_repo = RedisModelEndpointCacheRepository( + redis_client=redis, ) async_model_endpoint_inference_gateway = LiveAsyncModelEndpointInferenceGateway( - task_queue_gateway=sqs_task_queue_gateway + task_queue_gateway=inference_task_queue_gateway ) streaming_model_endpoint_inference_gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=monitoring_metrics_gateway, use_asyncio=(not CIRCLECI), ) sync_model_endpoint_inference_gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=monitoring_metrics_gateway, use_asyncio=(not CIRCLECI), ) - filesystem_gateway = S3FilesystemGateway() + filesystem_gateway = ( + ABSFilesystemGateway() + if infra_config().cloud_provider == "azure" + else S3FilesystemGateway() + ) model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=filesystem_gateway ) @@ -99,6 +132,8 @@ async def run_batch_job( streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway, sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, + can_scale_http_endpoint_from_zero_flag=False, # shouldn't matter since we only use this to create async endpoints ) batch_job_record_repository = DbBatchJobRecordRepository(session=session, read_only=False) batch_job_progress_gateway = LiveBatchJobProgressGateway(filesystem_gateway=filesystem_gateway) @@ -124,10 +159,7 @@ def entrypoint(): parser = argparse.ArgumentParser() parser.add_argument("--job-id", "-j", required=True, help="The ID of the batch job to run.") parser.add_argument( - "--owner", - "-o", - required=True, - help="The ID of the user who owns the batch job.", + "--owner", "-o", required=True, help="The ID of the user who owns the batch job." ) parser.add_argument("--input-path", "-i", required=True, help="The path to the input data.") parser.add_argument( diff --git a/server/llm_engine_server/entrypoints/start_docker_image_batch_job_init_container.py b/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py similarity index 76% rename from server/llm_engine_server/entrypoints/start_docker_image_batch_job_init_container.py rename to model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py index c552b09a..1c0048be 100644 --- a/server/llm_engine_server/entrypoints/start_docker_image_batch_job_init_container.py +++ b/model-engine/model_engine_server/entrypoints/start_docker_image_batch_job_init_container.py @@ -1,13 +1,13 @@ import argparse import shutil -import llm_engine_server.core.aws.storage_client as storage_client -from llm_engine_server.common.serialization_utils import b64_to_str -from llm_engine_server.core.aws.storage_client import s3_fileobj_exists -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.core.utils.url import parse_attachment_url +import model_engine_server.core.aws.storage_client as storage_client +from model_engine_server.common.serialization_utils import b64_to_str +from model_engine_server.core.aws.storage_client import s3_fileobj_exists +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.url import parse_attachment_url -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) def main(input_local: str, local_file: str, remote_file: str, file_contents_b64encoded: str): @@ -41,9 +41,4 @@ def main(input_local: str, local_file: str, remote_file: str, file_contents_b64e parser.add_argument("--remote-file", type=str) parser.add_argument("--file-contents-b64encoded", type=str) args = parser.parse_args() - main( - args.input_local, - args.local_file, - args.remote_file, - args.file_contents_b64encoded, - ) + main(args.input_local, args.local_file, args.remote_file, args.file_contents_b64encoded) diff --git a/server/llm_engine_server/entrypoints/start_fastapi_server.py b/model-engine/model_engine_server/entrypoints/start_fastapi_server.py similarity index 92% rename from server/llm_engine_server/entrypoints/start_fastapi_server.py rename to model-engine/model_engine_server/entrypoints/start_fastapi_server.py index d120a31b..90271625 100644 --- a/server/llm_engine_server/entrypoints/start_fastapi_server.py +++ b/model-engine/model_engine_server/entrypoints/start_fastapi_server.py @@ -3,6 +3,7 @@ You can do this with `start-fastapi-server`. """ + import argparse import subprocess from typing import List @@ -22,11 +23,11 @@ def start_gunicorn_server(port: int, num_workers: int, debug: bool) -> None: "--keep-alive", "2", "--worker-class", - "llm_engine_server.api.worker.LLMEngineWorker", + "model_engine_server.api.worker.LaunchWorker", "--workers", f"{num_workers}", *additional_args, - "llm_engine_server.api.app:app", + "model_engine_server.api.app:app", ] subprocess.run(command, check=True) diff --git a/server/llm_engine_server/inference/infra/__init__.py b/model-engine/model_engine_server/inference/__init__.py similarity index 100% rename from server/llm_engine_server/inference/infra/__init__.py rename to model-engine/model_engine_server/inference/__init__.py diff --git a/server/llm_engine_server/inference/infra/gateways/__init__.py b/model-engine/model_engine_server/inference/async_inference/__init__.py similarity index 100% rename from server/llm_engine_server/inference/infra/gateways/__init__.py rename to model-engine/model_engine_server/inference/async_inference/__init__.py diff --git a/server/llm_engine_server/inference/async_inference/celery.py b/model-engine/model_engine_server/inference/async_inference/celery.py similarity index 64% rename from server/llm_engine_server/inference/async_inference/celery.py rename to model-engine/model_engine_server/inference/async_inference/celery.py index 80ba64a0..3ea5db6d 100644 --- a/server/llm_engine_server/inference/async_inference/celery.py +++ b/model-engine/model_engine_server/inference/async_inference/celery.py @@ -1,23 +1,24 @@ import os -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.core.celery import TaskVisibility, celery_app -from llm_engine_server.inference.common import unset_sensitive_envvars +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.core.celery import TaskVisibility, celery_app +from model_engine_server.inference.common import unset_sensitive_envvars unset_sensitive_envvars() broker_type_str = os.getenv("BROKER_TYPE") broker_type = BrokerType(broker_type_str) s3_bucket: str = os.environ.get("CELERY_S3_BUCKET") # type: ignore celery_kwargs = dict( - name="llm_engine_server.inference.async_inference", - modules=["llm_engine_server.inference.async_inference.tasks"], + name="model_engine_server.inference.async_inference", + modules=["model_engine_server.inference.async_inference.tasks"], aws_role=os.environ["AWS_PROFILE"], s3_bucket=s3_bucket, # s3_base_path = TODO get from env var/config task_reject_on_worker_lost=False, worker_proc_alive_timeout=1500, broker_type=broker_type_str, - task_visibility=TaskVisibility.VISIBILITY_24H, # We're using SQS so this only changes task_time_limit + task_visibility=TaskVisibility.VISIBILITY_24H, + # We're using SQS so this only changes task_time_limit ) if broker_type == BrokerType.SQS: queue_name = os.getenv("SQS_QUEUE_NAME") @@ -26,7 +27,6 @@ dict(broker_transport_options={"predefined_queues": {queue_name: {"url": queue_url}}}) ) - async_inference_service = celery_app(**celery_kwargs) # type: ignore if __name__ == "__main__": diff --git a/server/llm_engine_server/inference/async_inference/tasks.py b/model-engine/model_engine_server/inference/async_inference/tasks.py similarity index 57% rename from server/llm_engine_server/inference/async_inference/tasks.py rename to model-engine/model_engine_server/inference/async_inference/tasks.py index 999cc270..69f9c9d0 100644 --- a/server/llm_engine_server/inference/async_inference/tasks.py +++ b/model-engine/model_engine_server/inference/async_inference/tasks.py @@ -3,24 +3,17 @@ from celery import Task from celery.signals import worker_process_init -from llm_engine_server.common.constants import READYZ_FPATH -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.common.serialization_utils import str_to_bool -from llm_engine_server.core.loggers import make_logger -from llm_engine_server.core.utils.timer import timer -from llm_engine_server.domain.entities import ModelEndpointConfig -from llm_engine_server.inference.async_inference.celery import async_inference_service -from llm_engine_server.inference.common import ( - get_endpoint_config, - load_predict_fn_or_cls, - run_predict, -) -from llm_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( - DatadogInferenceMonitoringMetricsGateway, -) -from llm_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler - -logger = make_logger(__name__) +from model_engine_server.common.constants import READYZ_FPATH +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.common.serialization_utils import str_to_bool +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.entities import ModelEndpointConfig +from model_engine_server.inference.async_inference.celery import async_inference_service +from model_engine_server.inference.common import load_predict_fn_or_cls, run_predict +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler + +logger = make_logger(logger_name()) # This should be safe as long as the celery workers are separate processes # (or we're using pool=solo) so they're not shared between threads @@ -35,16 +28,6 @@ def init_worker_global(): with timer(logger=logger, name="load_predict_fn_or_cls"): predict_fn_or_cls = load_predict_fn_or_cls() - endpoint_config = get_endpoint_config() - hooks = PostInferenceHooksHandler( - endpoint_name=endpoint_config.endpoint_name, - bundle_name=endpoint_config.bundle_name, - post_inference_hooks=endpoint_config.post_inference_hooks, - user_id=endpoint_config.user_id, - default_callback_url=endpoint_config.default_callback_url, - default_callback_auth=endpoint_config.default_callback_auth, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), - ) # k8s health check with open(READYZ_FPATH, "w") as f: f.write("READY") @@ -86,12 +69,11 @@ def predict(self, request_params, return_pickled): request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params) return run_predict(predict_fn_or_cls, request_params_pydantic) # type: ignore - def on_success(self, retval, task_id, args, kwargs): - request_params = args[0] - request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params) - hooks.handle(request_params_pydantic, retval, task_id) # type: ignore - -@async_inference_service.task(base=InferenceTask) +@async_inference_service.task( + base=InferenceTask, + # For legacy reasons, we need to use the old name. + name="hosted_model_inference.inference.async_inference.tasks.predict", +) def predict(request_params: Dict[str, Any], return_pickled=True): return predict.predict(request_params, return_pickled) diff --git a/server/llm_engine_server/inference/async_inference/vpa.yaml b/model-engine/model_engine_server/inference/async_inference/vpa.yaml similarity index 100% rename from server/llm_engine_server/inference/async_inference/vpa.yaml rename to model-engine/model_engine_server/inference/async_inference/vpa.yaml diff --git a/server/llm_engine_server/inference/base.Dockerfile b/model-engine/model_engine_server/inference/base.Dockerfile similarity index 69% rename from server/llm_engine_server/inference/base.Dockerfile rename to model-engine/model_engine_server/inference/base.Dockerfile index ab0f9310..88a7f7bf 100644 --- a/server/llm_engine_server/inference/base.Dockerfile +++ b/model-engine/model_engine_server/inference/base.Dockerfile @@ -3,6 +3,8 @@ FROM ${BASE_IMAGE} WORKDIR /app +RUN rm -rf /var/lib/apt/lists/* + # Install basic packages. RUN apt-get update && apt-get install -y \ apt-utils \ @@ -22,9 +24,9 @@ RUN apt-get update && apt-get install -y \ build-essential \ && rm -rf /var/lib/apt/lists/* -COPY --chown=root llm_engine /app/llm_engine -WORKDIR /app/llm_engine +COPY --chown=root model-engine /app/model-engine +WORKDIR /app/model-engine RUN pip install -e . WORKDIR /app -RUN pip install -r /app/llm_engine/llm_engine/inference/requirements_base.txt +RUN pip install -r /app/model-engine/model_engine_server/inference/requirements_base.txt diff --git a/model-engine/model_engine_server/inference/batch_inference/README.md b/model-engine/model_engine_server/inference/batch_inference/README.md new file mode 100644 index 00000000..0c380633 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/README.md @@ -0,0 +1,3 @@ +# Notes + +We will merge this with inference/vllm. In the meantime, you can build the batch image via inference/vllm/build_and_publish_image.sh \ No newline at end of file diff --git a/server/llm_engine_server/inference/sync_inference/__init__.py b/model-engine/model_engine_server/inference/batch_inference/__init__.py similarity index 100% rename from server/llm_engine_server/inference/sync_inference/__init__.py rename to model-engine/model_engine_server/inference/batch_inference/__init__.py diff --git a/model-engine/model_engine_server/inference/batch_inference/dto.py b/model-engine/model_engine_server/inference/batch_inference/dto.py new file mode 100644 index 00000000..f46f62a1 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/dto.py @@ -0,0 +1,180 @@ +# This is a copy of model_engine_server.common.dtos.llms.batch_completion.py +# This is done to decouple the pydantic requirements since vllm requires pydantic >2 +# while model engine is on 1.x +from enum import Enum +from typing import Dict, List, Optional + +from pydantic import BaseModel, ConfigDict, Field + + +class TokenOutput(BaseModel): + token: str + log_prob: float + + +class CompletionOutput(BaseModel): + text: str + num_prompt_tokens: int + num_completion_tokens: int + tokens: Optional[List[TokenOutput]] = None + + +class CreateBatchCompletionsRequestContent(BaseModel): + prompts: List[str] + max_new_tokens: int + temperature: float = Field(ge=0.0, le=1.0) + """ + Temperature of the sampling. Setting to 0 equals to greedy sampling. + """ + stop_sequences: Optional[List[str]] = None + """ + List of sequences to stop the completion at. + """ + return_token_log_probs: Optional[bool] = False + """ + Whether to return the log probabilities of the tokens. + """ + presence_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on whether they appear in the text so far. 0.0 means no penalty + """ + frequency_penalty: Optional[float] = Field(default=None, ge=0.0, le=2.0) + """ + Only supported in vllm, lightllm + Penalize new tokens based on their existing frequency in the text so far. 0.0 means no penalty + """ + top_k: Optional[int] = Field(default=None, ge=-1) + """ + Controls the number of top tokens to consider. -1 means consider all tokens. + """ + top_p: Optional[float] = Field(default=None, gt=0.0, le=1.0) + """ + Controls the cumulative probability of the top tokens to consider. 1.0 means consider all tokens. + """ + skip_special_tokens: Optional[bool] = True + """ + Whether to skip special tokens in the output. + """ + + +class Quantization(str, Enum): + BITSANDBYTES = "bitsandbytes" + AWQ = "awq" + + +class CreateBatchCompletionsModelConfig(BaseModel): + model: str + checkpoint_path: Optional[str] = None + """ + Path to the checkpoint to load the model from. + """ + num_shards: Optional[int] = 1 + """ + Suggested number of shards to distribute the model. When not specified, will infer the number of shards based on model config. + System may decide to use a different number than the given value. + """ + quantize: Optional[Quantization] = None + """ + Whether to quantize the model. + """ + seed: Optional[int] = None + """ + Random seed for the model. + """ + + +class ToolConfig(BaseModel): + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + name: str + """ + Name of the tool to use for the batch inference. + """ + max_iterations: Optional[int] = 10 + """ + Maximum number of iterations to run the tool. + """ + execution_timeout_seconds: Optional[int] = 60 + """ + Maximum runtime of the tool in seconds. + """ + should_retry_on_error: Optional[bool] = True + """ + Whether to retry the tool on error. + """ + + +class CreateBatchCompletionsRequest(BaseModel): + """ + Request object for batch completions. + """ + + input_data_path: Optional[str] = None + output_data_path: str + """ + Path to the output file. The output file will be a JSON file of type List[CompletionOutput]. + """ + labels: Dict[str, str] = Field( + default={}, description="Labels to attach to the batch inference job." + ) + content: Optional[CreateBatchCompletionsRequestContent] = None + """ + Either `input_data_path` or `content` needs to be provided. + When input_data_path is provided, the input file should be a JSON file of type BatchCompletionsRequestContent. + """ + + data_parallelism: int = Field(default=1, ge=1, le=64) + """ + Number of replicas to run the batch inference. More replicas are slower to schedule but faster to inference. + """ + max_runtime_sec: Optional[int] = Field(default=24 * 3600, ge=1, le=2 * 24 * 3600) + """ + Maximum runtime of the batch inference in seconds. Default to one day. + """ + tool_config: Optional[ToolConfig] = None + """ + Configuration for tool use. + NOTE: this config is highly experimental and signature will change significantly in future iterations. + """ + + max_context_length: Optional[int] = Field( + default=None, + ge=1, + description="Maximum context length to use for the model. Defaults to the max allowed by the model", + ) + + +class VLLMEngineAdditionalArgs(BaseModel): + max_gpu_memory_utilization: Optional[float] = Field( + default=0.9, + le=1.0, + description="Maximum GPU memory utilization for the model. Default to 90%.", + ) + + attention_backend: Optional[str] = Field( + default=None, + description="Attention backend to use for vLLM. Default to None.", + ) + + +class CreateBatchCompletionsEngineRequest(CreateBatchCompletionsRequest, VLLMEngineAdditionalArgs): + """ + Internal model for representing request to the inference framework. This contains additional fields that we want + hidden from the DTO exposed to the client. + """ + + model_config = ConfigDict(populate_by_name=True, protected_namespaces=()) + + model_cfg: CreateBatchCompletionsModelConfig = Field(alias="model_config") + """ + Model configuration for the batch inference. Hardware configurations are inferred. + + We rename model_config from api to model_cfg in engine since engine uses pydantic v2 which + reserves model_config as a keyword. + + We alias `model_config` for deserialization for backwards compatibility. + """ diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/generate_tool_sample_data.py b/model-engine/model_engine_server/inference/batch_inference/examples/generate_tool_sample_data.py new file mode 100644 index 00000000..d60a76d4 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/generate_tool_sample_data.py @@ -0,0 +1,79 @@ +import json + +COMPLETION_PROMPT1 = """\ +FYI: you can write code like this: +```python +import math +print(math.sqrt(2)) +``` +1.41... +>>> + +For reference, the third digit of 4.32 is 2. Also, use "Final Answer: X" to indicate your final answer. + +### Problem: + +What is the 4th digit of pi? + +### Answer: +```python +import math +print(math.pi) +``` +3.141592653589793 +>>> + +Final Answer: 1 + +### Problem: + +What is the 4th digit of the square root of 2? + +### Answer: +""" + +COMPLETION_PROMPT2 = """\ +FYI: you can write code like this: +```python +import math +print(math.sqrt(2)) +``` +1.41... +>>> + +For reference, the third digit of 4.32 is 2. Also, use "Final Answer: X" to indicate your final answer. + +### Problem: + +What is the 4th digit of pi? + +### Answer: +```python +import math +print(math.pi) +``` +3.141592653589793 +>>> + +Final Answer: 1 + +### Problem: + +What is the 5th digit of the square root of 2? + +### Answer: +""" + +data = { + "prompts": [ + COMPLETION_PROMPT1, + COMPLETION_PROMPT2, + "what is deep learning", + ], + "max_new_tokens": 100, + "temperature": 0.0, + "return_token_log_probs": True, + "stop_sequences": ["", "\n### Problem:\n", ">>>\n"], +} + +json.dump(data, open("sample_data_tool.json", "w")) diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_config.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config.json new file mode 100644 index 00000000..8944e66c --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config.json @@ -0,0 +1,14 @@ +{ + "input_data_path": "./examples/sample_data.json", + "output_data_path": "./examples/sample_output.json", + "model_config": { + "model": "mixtral-8x7b-instruct-v0.1", + "checkpoint_path": "my_path", + "num_shards": 2, + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_gemma.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_gemma.json new file mode 100644 index 00000000..e988c2f9 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_gemma.json @@ -0,0 +1,14 @@ +{ + "input_data_path": "./examples/sample_data.json", + "output_data_path": "./examples/sample_output.json", + "model_config": { + "model": "gemma-2-2b-it", + "checkpoint_path": "my_path", + "num_shards": 1, + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_mixtral.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_mixtral.json new file mode 100644 index 00000000..2c5fcc97 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_mixtral.json @@ -0,0 +1,13 @@ +{ + "input_data_path": "./examples/sample_data.json", + "output_data_path": "./examples/sample_output.json", + "model_config": { + "model": "mixtral-8x7b-instruct-v0.1", + "checkpoint_path": "my_path", + "num_shards": 2, + "labels": { + "team": "my_team" + } + }, + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_tool.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_tool.json new file mode 100644 index 00000000..457131d7 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_config_tool.json @@ -0,0 +1,16 @@ +{ + "input_data_path": "./sample_data_tool.json", + "output_data_path": "./sample_output_tool.json", + "model_config": { + "model": "gemma-2-2b-it", + "checkpoint_path": "/workspace/model_files/gemma-2-2b-it", + "num_shards": 1, + "labels": { + "team": "my_team" + } + }, + "data_parallelism": 2, + "tool_config": { + "name": "code_evaluator" + } +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_data.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_data.json new file mode 100644 index 00000000..d8fa3a68 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_data.json @@ -0,0 +1,9 @@ +{ + "prompts": [ + "san francisco is", + "deep learning is" + ], + "max_new_tokens": 100, + "temperature": 0.0, + "return_token_log_probs": true +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/examples/sample_data_tool.json b/model-engine/model_engine_server/inference/batch_inference/examples/sample_data_tool.json new file mode 100644 index 00000000..f529eca4 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/examples/sample_data_tool.json @@ -0,0 +1,15 @@ +{ + "prompts": [ + "FYI: you can write code like this: \n```python\nimport math\nprint(math.sqrt(2))\n```\n1.41...\n>>>\n\nFor reference, the third digit of 4.32 is 2. Also, use \"Final Answer: X\" to indicate your final answer.\n\n### Problem:\n\nWhat is the 4th digit of pi?\n\n### Answer:\n```python\nimport math\nprint(math.pi)\n```\n3.141592653589793\n>>>\n\nFinal Answer: 1\n\n### Problem:\n\nWhat is the 4th digit of the square root of 2?\n\n### Answer: \n", + "FYI: you can write code like this: \n```python\nimport math\nprint(math.sqrt(2))\n```\n1.41...\n>>>\n\nFor reference, the third digit of 4.32 is 2. Also, use \"Final Answer: X\" to indicate your final answer.\n\n### Problem:\n\nWhat is the 4th digit of pi?\n\n### Answer:\n```python\nimport math\nprint(math.pi)\n```\n3.141592653589793\n>>>\n\nFinal Answer: 1\n\n### Problem:\n\nWhat is the 5th digit of the square root of 2?\n\n### Answer: \n", + "what is deep learning" + ], + "max_new_tokens": 100, + "temperature": 0.0, + "return_token_log_probs": true, + "stop_sequences": [ + "", + "\n### Problem:\n", + ">>>\n" + ] +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/requirements.txt b/model-engine/model_engine_server/inference/batch_inference/requirements.txt new file mode 100644 index 00000000..89413c72 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/requirements.txt @@ -0,0 +1,8 @@ +vllm==0.5.3.post1 +pydantic>=2 +boto3==1.34.15 +smart-open==6.4.0 +ddtrace==2.4.0 +docker==7.0.0 +func-timeout==4.3.5 +datadog==0.49.1 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py new file mode 100644 index 00000000..d4863391 --- /dev/null +++ b/model-engine/model_engine_server/inference/batch_inference/vllm_batch.py @@ -0,0 +1,539 @@ +import argparse +import asyncio +import json +import multiprocessing +import os +import subprocess +import sys +import time +import uuid +from multiprocessing.pool import ThreadPool +from typing import List, Optional, Type +from urllib.parse import urlparse + +import boto3 +import smart_open +from func_timeout import FunctionTimedOut, func_set_timeout +from model_engine_server.inference.batch_inference.dto import ( + CompletionOutput, + CreateBatchCompletionsEngineRequest, + CreateBatchCompletionsRequestContent, + TokenOutput, + ToolConfig, +) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) +from model_engine_server.inference.tool_completion.tools import TOOL_MAP, BaseTool, Tools, tokenizer +from tqdm import tqdm + +CONFIG_FILE = os.getenv("CONFIG_FILE") +AWS_REGION = os.getenv("AWS_REGION", "us-west-2") +MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights") + +SKIP_AWS_PROFILE_SET = os.getenv("SKIP_AWS_PROFILE_SET", "false").lower() == "true" +if not SKIP_AWS_PROFILE_SET: + os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + + +def get_cpu_cores_in_container(): + cpu_count = multiprocessing.cpu_count() + try: + with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us") as fp: + cfs_quota_us = int(fp.read()) + with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us") as fp: + cfs_period_us = int(fp.read()) + if cfs_quota_us != -1: + cpu_count = cfs_quota_us // cfs_period_us + except FileNotFoundError: + pass + return cpu_count + + +CPU_COUNT = get_cpu_cores_in_container() + + +def get_s3_client(): + session = boto3.Session(profile_name=os.getenv("S3_WRITE_AWS_PROFILE")) + return session.client("s3", region_name=AWS_REGION) + + +def download_model(checkpoint_path, final_weights_folder): + s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.bin' --include '*.safetensors' --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {final_weights_folder}" + env = os.environ.copy() + env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + # Need to override these env vars so s5cmd uses AWS_PROFILE + env["AWS_ROLE_ARN"] = "" + env["AWS_WEB_IDENTITY_TOKEN_FILE"] = "" + process = subprocess.Popen( + s5cmd, + shell=True, # nosemgrep + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=env, + ) + for line in process.stdout: + print(line, flush=True) + + process.wait() + + if process.returncode != 0: + stderr_lines = [] + for line in iter(process.stderr.readline, ""): + stderr_lines.append(line.strip()) + + print(f"Error downloading model weights: {stderr_lines}", flush=True) + + +def file_exists(path): + try: + with smart_open.open(path, "r"): + return True + except Exception as exc: + print(f"Error checking if file exists: {exc}") + return False + + +def parse_s3_url(s3_url): + parsed_url = urlparse(s3_url) + + if parsed_url.scheme != "s3": + raise ValueError(f'The URL scheme is not "s3": {s3_url}') + + bucket = parsed_url.netloc + key = parsed_url.path.lstrip("/") + + return bucket, key + + +def wait_for_all_chunks(request): + # Max wait time is controlled by the batch job timeout + while True: + print("Waiting for all chunks to be written...") + all_chunks_exist = True + for i in range(request.data_parallelism): + chunk_file = f"{request.output_data_path}.{i}" + if not file_exists(chunk_file): + print(f"Chunk {chunk_file} does not exist yet") + all_chunks_exist = False + break + if all_chunks_exist: + break + time.sleep(5) + print("All chunks written") + + +def combine_all_chunks(request): + print("Combining chunks...") + with smart_open.open(request.output_data_path, "w") as f: + f.write("[") + for i in range(request.data_parallelism): + if i > 0: + f.write(",") + chunk_file = f"{request.output_data_path}.{i}" + with smart_open.open(chunk_file, "r") as chunk_f: + chunk_data = chunk_f.read() + f.write(chunk_data[1:-1]) # Remove leading and trailing brackets + f.write("]") + print("Chunks combined") + + +def delete_s3_chunks(request): + print("Deleting S3 chunks...") + for i in range(request.data_parallelism): + chunk_file = f"{request.output_data_path}.{i}" + bucket, key = parse_s3_url(chunk_file) + get_s3_client().delete_object(Bucket=bucket, Key=key) + print("Chunks deleted") + + +def random_uuid() -> str: + return str(uuid.uuid4().hex) + + +def get_vllm_engine(model: str, request: CreateBatchCompletionsEngineRequest): + from vllm import AsyncEngineArgs, AsyncLLMEngine + + engine_args = AsyncEngineArgs( + model=model, + quantization=request.model_cfg.quantize, + tensor_parallel_size=request.model_cfg.num_shards, + seed=request.model_cfg.seed or 0, + disable_log_requests=True, + gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9, + max_model_len=request.max_context_length, + ) + + llm = AsyncLLMEngine.from_engine_args(engine_args) + return llm + + +async def generate_with_tool( + llm, + tool_config: ToolConfig, + content: CreateBatchCompletionsRequestContent, + prompts, + tool: Type[BaseTool], + is_finetuned: bool, + model: str, +): + class IterativeGeneration: + def __init__(self, prompt, max_new_tokens): + self.generated_text = "" + self.num_prompt_tokens = 0 + self.remaining_tokens = max_new_tokens + self.token_logits = [] + self.tool_exception = None + self.prompt = prompt + self.completed = False + + def __repr__(self) -> str: + return f"generated_text: {self.generated_text}, num_prompt_tokens: {self.num_prompt_tokens}, remaining_tokens: {self.remaining_tokens}, tool_exception: {self.tool_exception}, prompt: {self.prompt}, completed: {self.completed}" + + num_iters = 0 + generations = [IterativeGeneration(prompt, content.max_new_tokens) for prompt in prompts] + max_iterations = tool_config.max_iterations or 10 + stop_sequences = content.stop_sequences or [] + stop_sequences.append(tool.tool_context_end) + + while num_iters < max_iterations: + num_iters += 1 + + iter_prompts = [ + (gen.prompt + gen.generated_text, idx) + for idx, gen in enumerate(generations) + if not gen.completed + ] + + if not iter_prompts: + break + + bar = tqdm( + total=len(iter_prompts), + desc=f"Generating outputs, iteration {num_iters}", + file=sys.stdout, + ) + + outputs = await generate_with_vllm( + llm, + [generations[iter[1]].remaining_tokens for iter in iter_prompts], + content.temperature, + content.stop_sequences, + content.return_token_log_probs, + content.presence_penalty, + content.frequency_penalty, + content.top_k, + content.top_p, + content.skip_special_tokens, + [iter[0] for iter in iter_prompts], + bar, + use_tool=True, + is_finetuned=is_finetuned, + model=model, + ) + + bar = tqdm( + total=len(iter_prompts), + desc=f"Running tools, iteration {num_iters}", + file=sys.stdout, + ) + + def tool_func(i): + bar.update(1) + response = outputs[i] + gen_item = generations[iter_prompts[i][1]] + new_text = response.text + + if content.return_token_log_probs: + gen_item.token_logits += response.tokens + + if not gen_item.num_prompt_tokens: + gen_item.num_prompt_tokens = response.num_prompt_tokens + + # break the loop if generation is complete even if remaining_tokens>0 + if len(new_text) == 0: + gen_item.completed = True + return + + # To-do write tools to receive response object itself rather than the text + try: + # We need to pass the tool/text to a function that times out if the python code can't execute + @func_set_timeout(tool_config.execution_timeout_seconds) + def tool_func(text: str, past_context: Optional[str]): + return tool()(text, past_context) + + past_context = ( + gen_item.generated_text if tool_config.should_retry_on_error else None + ) + new_text, num_tool_output_tokens = tool_func(new_text, past_context) + + except (Exception, FunctionTimedOut) as e: + # If the tool failed, we should add the error message to the generated text and keep going. It should be added right after the + # tool call token and concluded with the tool_context_end_token. + new_text_split = new_text.rsplit(tool.tool_call_token, 1) + + # We can guarantee this because the tool is not called if it doesn't have the tool call token + # We still want to replace what the LLM thinks the output should be.. + added_text = str(e) + tool.tool_context_end + subtracted_text = new_text_split[1] + + new_text = f"{new_text_split[0]}{tool.tool_call_token}{e}{tool.tool_context_end}" + + # Now let's add the additional tokens + num_tool_output_tokens = min( + len(tokenizer(added_text).input_ids) + - len(tokenizer(subtracted_text).input_ids), + 0, + ) + + # Also, define the tool exception here so we can raise it later + gen_item.tool_exception = e + + num_completion_tokens = response.num_completion_tokens + + gen_item.remaining_tokens -= num_completion_tokens + gen_item.remaining_tokens -= num_tool_output_tokens + gen_item.generated_text += new_text + + # If we didn't just execute a tool, we're done + if ( + not gen_item.generated_text.endswith(tool.tool_context_end) + or gen_item.remaining_tokens <= 0 + ): + gen_item.completed = True + + pool = ThreadPool(CPU_COUNT) + pool.map(tool_func, range(len(iter_prompts))) + + results = [ + CompletionOutput( + text=gen_item.generated_text, + num_prompt_tokens=gen_item.num_prompt_tokens, + num_completion_tokens=content.max_new_tokens - gen_item.remaining_tokens, + tokens=gen_item.token_logits if content.return_token_log_probs else None, + ) + for gen_item in generations + ] + + return results + + +async def batch_inference(config_file_data: Optional[str]): + job_index = int(os.getenv("JOB_COMPLETION_INDEX", 0)) + + if config_file_data is None: + if CONFIG_FILE is None or not os.path.exists(CONFIG_FILE): + raise FileNotFoundError(f"Config file {CONFIG_FILE} not found") + with open(CONFIG_FILE, "r") as f: + config_file_data = f.read() + + request = CreateBatchCompletionsEngineRequest.model_validate_json(config_file_data) + + if request.attention_backend is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = request.attention_backend + + if request.model_cfg.checkpoint_path is not None: + download_model(request.model_cfg.checkpoint_path, MODEL_WEIGHTS_FOLDER) + + content = request.content + if content is None: + with smart_open.open(request.input_data_path, "r") as f: + content = CreateBatchCompletionsRequestContent.model_validate_json(f.read()) + + model = MODEL_WEIGHTS_FOLDER if request.model_cfg.checkpoint_path else request.model_cfg.model + is_finetuned = request.model_cfg.checkpoint_path is not None + + llm = get_vllm_engine(model, request) + + prompts = [] + prompts_per_pod = len(content.prompts) // request.data_parallelism + if job_index == request.data_parallelism - 1: + for prompt in content.prompts[prompts_per_pod * job_index :]: + prompts.append(prompt) + else: + for prompt in content.prompts[ + prompts_per_pod * job_index : prompts_per_pod * (job_index + 1) + ]: + prompts.append(prompt) + + if request.tool_config is not None: + tool_enum = Tools(request.tool_config.name) + tool = TOOL_MAP[tool_enum] + outputs = await generate_with_tool( + llm, + request.tool_config, + content, + prompts, + tool, + is_finetuned, + request.model_cfg.model, + ) + else: + bar = tqdm(total=len(prompts), desc="Processed prompts") + + outputs = await generate_with_vllm( + llm, + [content.max_new_tokens] * len(prompts), + content.temperature, + content.stop_sequences, + content.return_token_log_probs, + content.presence_penalty, + content.frequency_penalty, + content.top_k, + content.top_p, + content.skip_special_tokens, + prompts, + bar, + use_tool=False, + is_finetuned=is_finetuned, + model=request.model_cfg.model, + ) + + bar.close() + + output_dicts = [output.dict() for output in outputs] + + if request.data_parallelism == 1: + with smart_open.open(request.output_data_path, "w") as f: + f.write(json.dumps(output_dicts)) + else: + chunk_file = f"{request.output_data_path}.{job_index}" + with smart_open.open(chunk_file, "w") as f: + f.write(json.dumps(output_dicts)) + if job_index == 0: + wait_for_all_chunks(request) + combine_all_chunks(request) + if request.output_data_path.startswith("s3://"): + delete_s3_chunks(request) + + +async def generate_with_vllm( + engine, + max_new_tokens, + temperature, + stop_sequences, + return_token_log_probs, + presence_penalty, + frequency_penalty, + top_k, + top_p, + skip_special_tokens, + prompts, + bar, + use_tool, + is_finetuned, + model, +) -> List[CompletionOutput]: # pragma: no cover + from vllm import SamplingParams + + metrics_gateway = DatadogInferenceMonitoringMetricsGateway() + + # Add the requests to the engine. + results_generators = [] + for idx, prompt in enumerate(prompts): + request_id = random_uuid() + sampling_params = SamplingParams( + max_tokens=max_new_tokens[idx], + temperature=temperature, + stop=stop_sequences, + logprobs=1 if return_token_log_probs else None, + presence_penalty=presence_penalty or 0.0, + frequency_penalty=frequency_penalty or 0.0, + top_k=top_k or -1, + top_p=top_p or 1.0, + skip_special_tokens=(skip_special_tokens if skip_special_tokens is not None else True), + ) + results_generator = await engine.add_request( + request_id, prompt, sampling_params, time.monotonic(), None + ) + results_generators.append(results_generator) + + outputs = [] + for generator in results_generators: + tokens = [] + async for request_output in generator: + if request_output.finished: + bar.update(1) + + if return_token_log_probs: + output = request_output.outputs[0] + log_probs = output.logprobs[-1] if return_token_log_probs else None + token_id = output.token_ids[-1] + tokens.append( + TokenOutput( + token=log_probs[token_id].decoded_token, + log_prob=log_probs[token_id].logprob, + ) + ) + + num_prompt_tokens = len(request_output.prompt_token_ids) + num_completion_tokens = len(request_output.outputs[0].token_ids) + + output = CompletionOutput( + text=request_output.outputs[0].text, + num_prompt_tokens=num_prompt_tokens, + num_completion_tokens=num_completion_tokens, + ) + if return_token_log_probs: + output.tokens = tokens + + metrics_gateway.emit_batch_completions_metric( + model, use_tool, num_prompt_tokens, num_completion_tokens, is_finetuned + ) + + outputs.append(output) + return outputs + + +def get_gpu_free_memory(): # pragma: no cover + """Get GPU free memory using nvidia-smi.""" + try: + output = subprocess.run( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"], + capture_output=True, + text=True, + ).stdout + gpu_memory = [int(x) for x in output.strip().split("\n")] + return gpu_memory + except Exception as e: + print(f"Error getting GPU memory: {e}") + return None + + +def check_unknown_startup_memory_usage(): # pragma: no cover + """Check for unknown memory usage at startup.""" + gpu_free_memory = get_gpu_free_memory() + if gpu_free_memory is not None: + print(f"GPU free memory at startup in MB: {gpu_free_memory}") + min_mem = min(gpu_free_memory) + max_mem = max(gpu_free_memory) + if max_mem - min_mem > 10: + print( + f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." + ) + try: + output = subprocess.run( + ["fuser -v /dev/nvidia*"], + shell=True, # nosemgrep + capture_output=True, + text=True, + ).stdout + print(f"Processes using GPU: {output}") + except Exception as e: + print(f"Error getting processes using GPU: {e}") + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-file-data", + "--config_file_data", + type=str, + default=None, + help="Optional override for the config file data, as a json string", + ) + args = parser.parse_args() + + check_unknown_startup_memory_usage() + asyncio.run(batch_inference(args.config_file_data)) diff --git a/server/llm_engine_server/inference/common.py b/model-engine/model_engine_server/inference/common.py similarity index 91% rename from server/llm_engine_server/inference/common.py rename to model-engine/model_engine_server/inference/common.py index e6191ca6..b8ddfea0 100644 --- a/server/llm_engine_server/inference/common.py +++ b/model-engine/model_engine_server/inference/common.py @@ -9,15 +9,15 @@ import boto3 import cloudpickle -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request, RequestSchema -from llm_engine_server.common.io import open_wrapper -from llm_engine_server.common.serialization_utils import b64_to_python_json -from llm_engine_server.core.loggers import make_logger -from llm_engine_server.core.utils.timer import timer -from llm_engine_server.domain.entities import ModelEndpointConfig -from llm_engine_server.inference.service_requests import make_request +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request, RequestSchema +from model_engine_server.common.io import open_wrapper +from model_engine_server.common.serialization_utils import b64_to_python_json +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.timer import timer +from model_engine_server.domain.entities import ModelEndpointConfig +from model_engine_server.inference.service_requests import make_request -logger = make_logger(__name__) +logger = make_logger(logger_name()) s3_client = None @@ -117,7 +117,6 @@ def load_predict_fn_or_cls(): return predict_fn_inner else: logger.info("Loading bundle from serialized object") - # e.g. s3://scale-ml/hosted-model-inference/predict_fns/abc123 with timer(logger=logger, name="download_and_deserialize_cloudpickle_bundle"): with open_wrapper(bundle_url, "rb") as f: @@ -131,7 +130,6 @@ def load_predict_fn_or_cls(): if "model" in bundle: model = bundle["model"] elif "load_model_fn" in bundle: - # e.g. s3://scale-ml/hosted-model-inference/tf-saved-models/tf-cpu-efficientdet-abc123.tar.gz with timer(logger=logger, name="load_model_fn"): if deserialized_config is None: model = bundle["load_model_fn"]() @@ -200,7 +198,7 @@ def predict_on_url(predict_fn: Callable, request_url: str, return_pickled: bool) def predict_on_args( predict_fn: Callable, inputs: RequestSchema, return_pickled: bool ) -> Dict[str, str]: - inputs_kwargs = inputs.__root__ + inputs_kwargs = inputs.root output = predict_fn(**inputs_kwargs) if return_pickled: @@ -268,14 +266,14 @@ def get_endpoint_config(): def is_sensitive_envvar(var): - return var.startswith("LLM_ENGINE_") or var.startswith("HMI_") + return var.startswith("LAUNCH_") or var.startswith("HMI_") def unset_sensitive_envvars(): # Since all the pods are in the same namespace as of now, there are env vars e.g. - # `LLM_ENGINE__...` that store the IPs of various services (and also leak that the services exist) + # `LAUNCH__...` that store the IPs of various services (and also leak that the services exist) # Let's unset them here - # The names seem to be the name of the deployment, which always starts with `LLM_ENGINE_` or `HMI_`. + # The names seem to be the name of the deployment, which always starts with `LAUNCH_` or `HMI_`. logger.info("Unsetting environment variables...") sensitive_envvars = [var for var in os.environ if is_sensitive_envvar(var)] for var in sensitive_envvars: diff --git a/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml b/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml new file mode 100644 index 00000000..0c9b43b4 --- /dev/null +++ b/model-engine/model_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml @@ -0,0 +1,22 @@ +forwarder: + sync: + user_port: 5005 + user_hostname: "localhost" + use_grpc: false + predict_route: "/predict" + healthcheck_route: "/readyz" + batch_route: null + model_engine_unwrap: false + serialize_results_as_string: false + wrap_response: false + forward_http_status: true + async: + user_port: 5005 + user_hostname: "localhost" + use_grpc: false + predict_route: "/predict" + healthcheck_route: "/readyz" + batch_route: null + model_engine_unwrap: false + serialize_results_as_string: false + wrap_response: false diff --git a/model-engine/model_engine_server/inference/configs/service--forwarder.yaml b/model-engine/model_engine_server/inference/configs/service--forwarder.yaml new file mode 100644 index 00000000..fea277db --- /dev/null +++ b/model-engine/model_engine_server/inference/configs/service--forwarder.yaml @@ -0,0 +1,20 @@ +forwarder: + sync: + user_port: 5005 + user_hostname: "localhost" + use_grpc: false + predict_route: "/predict" + healthcheck_route: "/readyz" + batch_route: null + model_engine_unwrap: true + serialize_results_as_string: true + forward_http_status: true + async: + user_port: 5005 + user_hostname: "localhost" + use_grpc: false + predict_route: "/predict" + healthcheck_route: "/readyz" + batch_route: null + model_engine_unwrap: true + serialize_results_as_string: true diff --git a/server/llm_engine_server/inference/configs/service--http_forwarder.yaml b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml similarity index 71% rename from server/llm_engine_server/inference/configs/service--http_forwarder.yaml rename to model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml index f37694f8..bfdb6553 100644 --- a/server/llm_engine_server/inference/configs/service--http_forwarder.yaml +++ b/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml @@ -6,13 +6,17 @@ forwarder: predict_route: "/predict" healthcheck_route: "/readyz" batch_route: null - llm_engine_unwrap: true + model_engine_unwrap: true serialize_results_as_string: true + forward_http_status: true + extra_routes: [] stream: user_port: 5005 user_hostname: "localhost" predict_route: "/stream" healthcheck_route: "/readyz" batch_route: null - llm_engine_unwrap: true + model_engine_unwrap: true serialize_results_as_string: false + extra_routes: [] + max_concurrency: 100 diff --git a/server/llm_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py b/model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py similarity index 60% rename from server/llm_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py rename to model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py index e23c2c74..15602563 100644 --- a/server/llm_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py +++ b/model-engine/model_engine_server/inference/domain/gateways/inference_monitoring_metrics_gateway.py @@ -5,7 +5,7 @@ Used to calculate proportion of successful/unsuccessful requests, differentiates between docker build vs other errors. -(Copy of llm_engine/domain/gateways/monitoring_metrics_gateway.py but used purely for +(Copy of model_engine_server/domain/gateways/monitoring_metrics_gateway.py but used purely for inference to avoid importing stuff in user code that we don't need.) """ @@ -30,3 +30,23 @@ def emit_successful_post_inference_hook(self, hook: str): Args: hook: The name of the hook """ + + @abstractmethod + def emit_async_task_received_metric(self, queue_name: str): + """ + Async task received metric + + Args: + queue_name: The name of the Celery queue + """ + pass + + @abstractmethod + def emit_async_task_stuck_metric(self, queue_name: str): + """ + Async task stuck metric + + Args: + queue_name: The name of the Celery queue + """ + pass diff --git a/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py b/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py new file mode 100644 index 00000000..add325bc --- /dev/null +++ b/model-engine/model_engine_server/inference/domain/gateways/streaming_storage_gateway.py @@ -0,0 +1,19 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict + + +class StreamingStorageGateway(ABC): + """ + Base class for a gateway that stores data through a streaming mechanism. + """ + + @abstractmethod + def put_record(self, stream_name: str, record: Dict[str, Any]) -> Dict[str, Any]: + """ + Put a record into a streaming storage mechanism. + + Args: + stream_name: The name of the stream. + record: The record to put into the stream. + """ + pass diff --git a/model-engine/model_engine_server/inference/domain/gateways/usage_metrics_gateway.py b/model-engine/model_engine_server/inference/domain/gateways/usage_metrics_gateway.py new file mode 100644 index 00000000..64161b3c --- /dev/null +++ b/model-engine/model_engine_server/inference/domain/gateways/usage_metrics_gateway.py @@ -0,0 +1,28 @@ +from abc import ABC, abstractmethod +from typing import Dict + + +class UsageMetricsGateway(ABC): + """ + Base class for gateway that emits usage metrics to some store of metrics, e.g. Datadog or + Platform Money Infra. + + Inside inference/ because otherwise we import tons of stuff (in particular hmi_config) that + isn't safe to import inside of the inference code (since it contains sensitive data) + + TODO this code (at least in its current form) should be considered temporary, it's to enable + instantml billing + """ + + @abstractmethod + def emit_task_call_metric(self, idempotency_token: str, tags: Dict[str, str]): + """ + Emits the billing event to the billing queue + Args: + idempotency_token: Some per-request token + tags: User-defined tags to get passed to billing. Should be for internal only. + Right now `tags` is pretty strictly formatted, + and reflects the scale FinancialEvent schema (see EventbridgeUsageMetricsGateway) + + """ + pass diff --git a/server/llm_engine_server/inference/download_and_inject_bundle.py b/model-engine/model_engine_server/inference/download_and_inject_bundle.py similarity index 93% rename from server/llm_engine_server/inference/download_and_inject_bundle.py rename to model-engine/model_engine_server/inference/download_and_inject_bundle.py index 8637be59..7fa1f726 100644 --- a/server/llm_engine_server/inference/download_and_inject_bundle.py +++ b/model-engine/model_engine_server/inference/download_and_inject_bundle.py @@ -2,9 +2,9 @@ import os import shutil -from llm_engine_server.core.loggers import make_logger +from model_engine_server.core.loggers import logger_name, make_logger -logger = make_logger(__name__) +logger = make_logger(logger_name()) LOCAL_BUNDLE_PATH = os.getenv("LOCAL_BUNDLE_PATH", "") LOAD_MODEL_MODULE_PATH = os.getenv("LOAD_MODEL_MODULE_PATH", "") diff --git a/server/llm_engine_server/infra/__init__.py b/model-engine/model_engine_server/inference/forwarding/__init__.py similarity index 100% rename from server/llm_engine_server/infra/__init__.py rename to model-engine/model_engine_server/inference/forwarding/__init__.py diff --git a/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py new file mode 100644 index 00000000..a4daff0c --- /dev/null +++ b/model-engine/model_engine_server/inference/forwarding/celery_forwarder.py @@ -0,0 +1,207 @@ +import argparse +import json +from datetime import datetime, timedelta +from typing import Any, Dict, Optional, TypedDict, Union + +from celery import Celery, Task, states +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.celery import ( + DEFAULT_TASK_VISIBILITY_SECONDS, + TaskVisibility, + celery_app, +) +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.format import format_stacktrace +from model_engine_server.inference.forwarding.forwarding import ( + Forwarder, + LoadForwarder, + load_named_config, +) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) + +logger = make_logger(logger_name()) + + +class ErrorResponse(TypedDict): + """The response payload for any inference request that encountered an error.""" + + error: str + error_metadata: str + + +def raw_celery_response(backend, task_id: str) -> Dict[str, Any]: + key_info: str = backend.get_key_for_task(task_id) + info_as_str: str = backend.get(key_info) + info: dict = json.loads(info_as_str) + return info + + +def error_response(msg: str, e_unhandled: Exception) -> ErrorResponse: + stacktrace = format_stacktrace(e_unhandled) + return { + "error": str(e_unhandled), + "error_metadata": f"{msg}\n{stacktrace}", + } + + +def create_celery_service( + forwarder: Forwarder, + task_visibility: TaskVisibility, + broker_type: str, + backend_protocol: str, + queue_name: Optional[str] = None, + sqs_url: Optional[str] = None, +) -> Celery: + """ + Creates a celery application. + Returns: + app (celery.app.base.Celery): Celery app. + exec_func (celery.local.PromiseProxy): Callable task function. + """ + + app: Celery = celery_app( + name=None, + s3_bucket=infra_config().s3_bucket, + aws_role=infra_config().profile_ml_inference_worker, + task_visibility=task_visibility, + broker_type=broker_type, + broker_transport_options=( + {"predefined_queues": {queue_name: {"url": sqs_url}}} + if broker_type == str(BrokerType.SQS.value) + else None + ), + backend_protocol=backend_protocol, + ) + + monitoring_metrics_gateway = DatadogInferenceMonitoringMetricsGateway() + + class ErrorHandlingTask(Task): + """Sets a 'custom' field with error in the Task response for FAILURE. + + Used when services are ran via the Celery backend. + """ + + def after_return( + self, status: str, retval: Union[dict, Exception], task_id: str, args, kwargs, einfo + ) -> None: + """Handler that ensures custom error response information is available whenever a Task fails. + + Specifically, whenever the task's :param:`status` is `"FAILURE"` and the return value + :param:`retval` is an `Exception`, this handler extracts information from the `Exception` + and constructs a custom error response JSON value (see :func:`error_response` for details). + + This handler then re-propagates the Celery-required exception information (`"exc_type"` and + `"exc_message"`) while adding this new error response information under the `"custom"` key. + """ + if status == states.FAILURE and isinstance(retval, Exception): + logger.warning(f"Setting custom error response for failed task {task_id}") + + info: dict = raw_celery_response(self.backend, task_id) + result: dict = info["result"] + err: Exception = retval + + error_payload = error_response("Internal failure", err) + + # Inspired by pattern from: + # https://www.distributedpython.com/2018/09/28/celery-task-states/ + self.update_state( + state=states.FAILURE, + meta={ + "exc_type": result["exc_type"], + "exc_message": result["exc_message"], + "custom": json.dumps(error_payload, indent=False), + }, + ) + request_params = args[0] + request_params_pydantic = EndpointPredictV1Request.parse_obj(request_params) + if forwarder.post_inference_hooks_handler: + forwarder.post_inference_hooks_handler.handle(request_params_pydantic, retval, task_id) # type: ignore + + # See documentation for options: + # https://docs.celeryproject.org/en/stable/userguide/tasks.html#list-of-options + @app.task(base=ErrorHandlingTask, name=LIRA_CELERY_TASK_NAME, track_started=True) + def exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs): + if len(ignored_args) > 0: + logger.warning(f"Ignoring {len(ignored_args)} positional arguments: {ignored_args=}") + if len(ignored_kwargs) > 0: + logger.warning(f"Ignoring {len(ignored_kwargs)} keyword arguments: {ignored_kwargs=}") + try: + monitoring_metrics_gateway.emit_async_task_received_metric(queue_name) + result = forwarder(payload) + request_duration = datetime.now() - arrival_timestamp + if request_duration > timedelta(seconds=DEFAULT_TASK_VISIBILITY_SECONDS): + monitoring_metrics_gateway.emit_async_task_stuck_metric(queue_name) + return result + except Exception: + logger.exception("Celery service failed to respond to request.") + raise + + # Have celery service also accept pre-LIRA celery task name to ensure no downtime + # when transitioning from pre-LIRA single container architecture to LIRA + # multi-container-architecture. + @app.task( + base=ErrorHandlingTask, + name=DEFAULT_CELERY_TASK_NAME, + track_started=True, + ) + def exec_func_pre_lira(payload, arrival_timestamp, *ignored_args, **ignored_kwargs): + return exec_func(payload, arrival_timestamp, *ignored_args, **ignored_kwargs) + + return app + + +def start_celery_service( + app: Celery, + queue: str, + concurrency: int, +) -> None: + worker = app.Worker( + queues=[queue], + concurrency=concurrency, + loglevel="INFO", + optimization="fair", + # pool="solo" argument fixes the known issues of celery and some of the libraries. + # Particularly asyncio and torchvision transformers. + pool="solo", + ) + worker.start() + + +def entrypoint(): + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--set", type=str, action="append") + parser.add_argument("--task-visibility", type=str, required=True) + parser.add_argument("--num-workers", type=int, required=True) + parser.add_argument("--broker-type", type=str, default=None) + parser.add_argument("--backend-protocol", type=str, default="s3") + parser.add_argument("--queue", type=str, required=True) + parser.add_argument("--sqs-url", type=str, default=None) + + args = parser.parse_args() + + if args.broker_type is None: + args.broker_type = str(BrokerType.SQS.value if args.sqs_url else BrokerType.REDIS.value) + + forwarder_config = load_named_config(args.config, args.set) + forwarder_loader = LoadForwarder(**forwarder_config["async"]) + forwader = forwarder_loader.load(None, None) + + app = create_celery_service( + forwader, + TaskVisibility.VISIBILITY_24H, + args.broker_type, + args.backend_protocol, + args.queue, + args.sqs_url, + ) + start_celery_service(app, args.queue, args.num_workers) + + +if __name__ == "__main__": + entrypoint() diff --git a/model-engine/model_engine_server/inference/forwarding/echo_server.py b/model-engine/model_engine_server/inference/forwarding/echo_server.py new file mode 100644 index 00000000..3d6333a3 --- /dev/null +++ b/model-engine/model_engine_server/inference/forwarding/echo_server.py @@ -0,0 +1,72 @@ +""" +This file is for testing purposes only. It serves as simple server to mock a deployed model. +""" + +import argparse +import asyncio +import subprocess + +from fastapi import FastAPI, Request +from fastapi.responses import JSONResponse +from sse_starlette.sse import EventSourceResponse + +app = FastAPI() + + +@app.get("/health") +@app.get("/healthz") +@app.get("/readyz") +def healthcheck(): + return "OK" + + +@app.post("/v1/chat/completions") +@app.post("/predict") +async def predict(request: Request): + dictionary = await request.json() + print("Received request", dictionary, flush=True) + if "delay" in dictionary: + await asyncio.sleep(dictionary["delay"]) + return dictionary + + +@app.post("/predict500") +async def predict500(request: Request): + response = JSONResponse(content=await request.json(), status_code=500) + return response + + +@app.post("/stream") +async def stream(request: Request): + value = (await request.body()).decode() + return EventSourceResponse([{"data": value}].__iter__()) + + +def entrypoint(): + parser = argparse.ArgumentParser() + parser.add_argument("--num-workers", type=int, default=1) + parser.add_argument("--host", type=str, default="[::]") + parser.add_argument("--port", type=int, default=5009) + + args, extra_args = parser.parse_known_args() + + command = [ + "gunicorn", + "--bind", + f"{args.host}:{args.port}", + "--timeout", + "1200", + "--keep-alive", + "2", + "--worker-class", + "uvicorn.workers.UvicornWorker", + "--workers", + str(args.num_workers), + "model_engine_server.inference.forwarding.echo_server:app", + *extra_args, + ] + subprocess.run(command) + + +if __name__ == "__main__": + entrypoint() diff --git a/server/llm_engine_server/inference/forwarding/forwarding.py b/model-engine/model_engine_server/inference/forwarding/forwarding.py similarity index 54% rename from server/llm_engine_server/inference/forwarding/forwarding.py rename to model-engine/model_engine_server/inference/forwarding/forwarding.py index 0517bc65..c3970107 100644 --- a/server/llm_engine_server/inference/forwarding/forwarding.py +++ b/model-engine/model_engine_server/inference/forwarding/forwarding.py @@ -1,19 +1,28 @@ +import ast import json import os import time from dataclasses import dataclass from pathlib import Path -from typing import Any, Iterator, Optional, Sequence, Tuple +from typing import Any, AsyncGenerator, Iterable, List, Optional, Sequence, Tuple +import aiohttp +import orjson import requests import sseclient -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.loggers import logger_name, make_logger -from llm_engine_server.inference.common import get_endpoint_config -from llm_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( +import yaml +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from model_engine_server.common.aiohttp_sse_client import EventSource +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.inference.common import get_endpoint_config +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( DatadogInferenceMonitoringMetricsGateway, ) -from llm_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( + FirehoseStreamingStorageGateway, +) +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler __all__: Sequence[str] = ( "Forwarder", @@ -31,10 +40,10 @@ DEFAULT_PORT: int = 5005 -class LLMEngineSerializationMixin: - """Mixin class for optionally wrapping LLMEngine requests.""" +class ModelEngineSerializationMixin: + """Mixin class for optionally wrapping Model Engine requests.""" - llm_engine_unwrap: bool + model_engine_unwrap: bool serialize_results_as_string: bool def _get_serialize_results_as_string_value( @@ -46,7 +55,7 @@ def _get_serialize_results_as_string_value( return serialize_results_as_string elif KEY_SERIALIZE_RESULTS_AS_STRING in json_payload: - serialize_results_as_string = json_payload[KEY_SERIALIZE_RESULTS_AS_STRING] + serialize_results_as_string = bool(json_payload[KEY_SERIALIZE_RESULTS_AS_STRING]) logger.warning( f"Found '{KEY_SERIALIZE_RESULTS_AS_STRING}' in payload! " f"Overriding {self.serialize_results_as_string=} with " @@ -68,15 +77,17 @@ def _get_serialize_results_as_string_value( def unwrap_json_payload(self, json_payload: Any) -> Tuple[Any, bool]: # TODO: eventually delete serialize_results_as_string: Optional[bool] = None + # IF we get a feature update in model_engine where it's able to allow a user to + # request this from the API, then we can determine that here. + # (NOTE: This is _potential_ future behavior) serialize_results_as_string = self._get_serialize_results_as_string_value( serialize_results_as_string, json_payload, # type: ignore ) - if self.llm_engine_unwrap: + if self.model_engine_unwrap: logger.info(f"Unwrapping {json_payload.keys()=}") - json_payload = json_payload["args"] - # TODO: eventually delete + json_payload = json_payload.get("args", json_payload) serialize_results_as_string = self._get_serialize_results_as_string_value( serialize_results_as_string, json_payload, # type: ignore @@ -91,16 +102,35 @@ def unwrap_json_payload(self, json_payload: Any) -> Tuple[Any, bool]: @staticmethod def get_response_payload(using_serialize_results_as_string: bool, response: Any): - # LLMEngine expects a JSON object with a "result" key. + # Model Engine expects a JSON object with a "result" key. if using_serialize_results_as_string: response_as_string: str = json.dumps(response) return {"result": response_as_string} return {"result": response} + @staticmethod + def get_response_payload_stream(using_serialize_results_as_string: bool, response: str): + """Event stream is needs to be treated as a stream of strings, not JSON objects""" + if using_serialize_results_as_string: + return {"result": response} + + return {"result": parse_to_object_or_string(response)} + + +def parse_to_object_or_string(value: str) -> object: + try: + return json.loads(value) + except json.JSONDecodeError: + return value + + +def _serialize_json(data) -> str: + return orjson.dumps(data).decode() + @dataclass -class Forwarder(LLMEngineSerializationMixin): +class Forwarder(ModelEngineSerializationMixin): """Forwards inference requests to another service via HTTP POST. Expects this user-defined inference service to be running on localhost. However, @@ -115,26 +145,71 @@ class Forwarder(LLMEngineSerializationMixin): """ predict_endpoint: str - llm_engine_unwrap: bool + model_engine_unwrap: bool serialize_results_as_string: bool - post_inference_hooks_handler: PostInferenceHooksHandler wrap_response: bool + forward_http_status: bool + post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None + + async def forward(self, json_payload: Any) -> Any: + json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) + json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload + + logger.info(f"Accepted request, forwarding {json_payload_repr=}") + + try: + async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: + response_raw = await aioclient.post( + self.predict_endpoint, + json=json_payload, + headers={"Content-Type": "application/json"}, + ) + response = await response_raw.json( + content_type=None + ) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error + + except Exception: + logger.exception( + f"Failed to get response for request ({json_payload_repr}) " + "from user-defined inference service." + ) + raise + if isinstance(response, dict): + logger.info( + f"Got response from user-defined service: {response.keys()=}, {response_raw.status=}" + ) + elif isinstance(response, list): + logger.info( + f"Got response from user-defined service: {len(response)=}, {response_raw.status=}" + ) + else: + logger.info( + f"Got response from user-defined service: {response=}, {response_raw.status=}" + ) + + if self.wrap_response: + response = self.get_response_payload(using_serialize_results_as_string, response) + + if self.forward_http_status: + return JSONResponse(content=response, status_code=response_raw.status) + else: + return response def __call__(self, json_payload: Any) -> Any: - request_obj = EndpointPredictV1Request.parse_obj(json_payload) json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload logger.info(f"Accepted request, forwarding {json_payload_repr=}") try: - response: Any = requests.post( + response_raw: Any = requests.post( self.predict_endpoint, json=json_payload, headers={ "Content-Type": "application/json", }, - ).json() + ) + response = response_raw.json() except Exception: logger.exception( f"Failed to get response for request ({json_payload_repr}) " @@ -142,18 +217,25 @@ def __call__(self, json_payload: Any) -> Any: ) raise if isinstance(response, dict): - logger.info(f"Got response from user-defined service: {response.keys()=}") + logger.info( + f"Got response from user-defined service: {response.keys()=}, {response_raw.status_code=}" + ) elif isinstance(response, list): - logger.info(f"Got response from user-defined service: {len(response)=}") + logger.info( + f"Got response from user-defined service: {len(response)=}, {response_raw.status_code=}" + ) else: - logger.info(f"Got response from user-defined service: {response=}") + logger.info( + f"Got response from user-defined service: {response=}, {response_raw.status_code=}" + ) if self.wrap_response: response = self.get_response_payload(using_serialize_results_as_string, response) - # TODO: we actually want to do this after we've returned the response. - self.post_inference_hooks_handler.handle(request_obj, response) - return response + if self.forward_http_status: + return JSONResponse(content=response, status_code=response_raw.status_code) + else: + return response @dataclass(frozen=True) @@ -174,12 +256,12 @@ class LoadForwarder: predict_route: str = "/predict" healthcheck_route: str = "/readyz" batch_route: Optional[str] = None - llm_engine_unwrap: bool = True - # TODO: this is a workaround + model_engine_unwrap: bool = True serialize_results_as_string: bool = True wrap_response: bool = True + forward_http_status: bool = False - def load(self, resources: Path, cache: Any) -> Forwarder: + def load(self, resources: Optional[Path], cache: Any) -> Forwarder: if self.use_grpc: raise NotImplementedError( "User-defined service **MUST** use HTTP at the moment. " @@ -236,7 +318,7 @@ def endpoint(route: str) -> str: logger.info(f"Waiting for user-defined service to be ready at {hc}...") time.sleep(1) - logger.info(f"Unwrapping spellbook payload formatting?: {self.llm_engine_unwrap}") + logger.info(f"Unwrapping model engine payload formatting?: {self.model_engine_unwrap}") logger.info(f"Serializing result as string?: {self.serialize_results_as_string}") if ENV_SERIALIZE_RESULTS_AS_STRING in os.environ: @@ -257,28 +339,39 @@ def endpoint(route: str) -> str: else: serialize_results_as_string = self.serialize_results_as_string - endpoint_config = get_endpoint_config() - handler = PostInferenceHooksHandler( - endpoint_name=endpoint_config.endpoint_name, - bundle_name=endpoint_config.bundle_name, - post_inference_hooks=endpoint_config.post_inference_hooks, - user_id=endpoint_config.user_id, - default_callback_url=endpoint_config.default_callback_url, - default_callback_auth=endpoint_config.default_callback_auth, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), - ) + try: + endpoint_config = get_endpoint_config() + handler = PostInferenceHooksHandler( + endpoint_name=endpoint_config.endpoint_name, + bundle_name=endpoint_config.bundle_name, + post_inference_hooks=endpoint_config.post_inference_hooks, + user_id=endpoint_config.user_id, + billing_queue=endpoint_config.billing_queue, + billing_tags=endpoint_config.billing_tags, + default_callback_url=endpoint_config.default_callback_url, + default_callback_auth=endpoint_config.default_callback_auth, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id=endpoint_config.endpoint_id, + endpoint_type=endpoint_config.endpoint_type, + bundle_id=endpoint_config.bundle_id, + labels=endpoint_config.labels, + streaming_storage_gateway=FirehoseStreamingStorageGateway(), + ) + except Exception: + handler = None return Forwarder( predict_endpoint=pred, - llm_engine_unwrap=self.llm_engine_unwrap, + model_engine_unwrap=self.model_engine_unwrap, serialize_results_as_string=serialize_results_as_string, post_inference_hooks_handler=handler, wrap_response=self.wrap_response, + forward_http_status=self.forward_http_status, ) @dataclass -class StreamingForwarder(LLMEngineSerializationMixin): +class StreamingForwarder(ModelEngineSerializationMixin): """Forwards inference requests to another service via HTTP POST. Expects this user-defined inference service to be running on localhost. However, @@ -294,11 +387,44 @@ class StreamingForwarder(LLMEngineSerializationMixin): """ predict_endpoint: str - llm_engine_unwrap: bool + model_engine_unwrap: bool serialize_results_as_string: bool - post_inference_hooks_handler: PostInferenceHooksHandler # unused for now + post_inference_hooks_handler: Optional[PostInferenceHooksHandler] = None # unused for now - def __call__(self, json_payload: Any) -> Iterator[Any]: + async def forward(self, json_payload: Any) -> AsyncGenerator[Any, None]: # pragma: no cover + json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) + json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload + + logger.info(f"Accepted request, forwarding {json_payload_repr=}") + + try: + response: aiohttp.ClientResponse + async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: + response = await aioclient.post( + self.predict_endpoint, + json=json_payload, + headers={"Content-Type": "application/json"}, + ) + + if response.status != 200: + raise HTTPException( + status_code=response.status, detail=await response.json(content_type=None) + ) # [Bug] upstream service doesn't always have the content type header set which causes aiohttp to error + + async with EventSource(response=response) as event_source: + async for event in event_source: + yield self.get_response_payload_stream( + using_serialize_results_as_string, event.data + ) + + except Exception: + logger.exception( + f"Failed to get response for request ({json_payload_repr}) " + "from user-defined inference service." + ) + raise + + def __call__(self, json_payload: Any) -> Iterable[Any]: json_payload, using_serialize_results_as_string = self.unwrap_json_payload(json_payload) json_payload_repr = json_payload.keys() if hasattr(json_payload, "keys") else json_payload @@ -313,6 +439,10 @@ def __call__(self, json_payload: Any) -> Iterator[Any]: }, stream=True, ) + + if response.status_code != 200: + raise HTTPException(status_code=response.status_code, detail=response.json()) + except Exception: logger.exception( f"Failed to get response for request ({json_payload_repr}) " @@ -321,10 +451,14 @@ def __call__(self, json_payload: Any) -> Iterator[Any]: raise client = sseclient.SSEClient(response) - for event in client.events(): - yield self.get_response_payload( - using_serialize_results_as_string, json.loads(event.data) - ) + + def event_stream(): + for event in client.events(): + yield self.get_response_payload_stream( + using_serialize_results_as_string, event.data + ) + + return event_stream() @dataclass(frozen=True) @@ -345,10 +479,10 @@ class LoadStreamingForwarder: predict_route: str = "/predict" healthcheck_route: str = "/readyz" batch_route: Optional[str] = None - llm_engine_unwrap: bool = True + model_engine_unwrap: bool = True serialize_results_as_string: bool = False - def load(self, resources: Path, cache: Any) -> StreamingForwarder: + def load(self, resources: Optional[Path], cache: Any) -> StreamingForwarder: if self.use_grpc: raise NotImplementedError( "User-defined service **MUST** use HTTP at the moment. " @@ -405,7 +539,7 @@ def endpoint(route: str) -> str: logger.info(f"Waiting for user-defined service to be ready at {hc}...") time.sleep(1) - logger.info(f"Unwrapping spellbook payload formatting?: {self.llm_engine_unwrap}") + logger.info(f"Unwrapping model engine payload formatting?: {self.model_engine_unwrap}") logger.info(f"Serializing result as string?: {self.serialize_results_as_string}") if ENV_SERIALIZE_RESULTS_AS_STRING in os.environ: @@ -426,20 +560,88 @@ def endpoint(route: str) -> str: else: serialize_results_as_string = self.serialize_results_as_string - endpoint_config = get_endpoint_config() - handler = PostInferenceHooksHandler( - endpoint_name=endpoint_config.endpoint_name, - bundle_name=endpoint_config.bundle_name, - post_inference_hooks=endpoint_config.post_inference_hooks, - user_id=endpoint_config.user_id, - default_callback_url=endpoint_config.default_callback_url, - default_callback_auth=endpoint_config.default_callback_auth, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), - ) + try: + endpoint_config = get_endpoint_config() + handler = PostInferenceHooksHandler( + endpoint_name=endpoint_config.endpoint_name, + bundle_name=endpoint_config.bundle_name, + post_inference_hooks=endpoint_config.post_inference_hooks, + user_id=endpoint_config.user_id, + billing_queue=endpoint_config.billing_queue, + billing_tags=endpoint_config.billing_tags, + default_callback_url=endpoint_config.default_callback_url, + default_callback_auth=endpoint_config.default_callback_auth, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id=endpoint_config.endpoint_id, + endpoint_type=endpoint_config.endpoint_type, + bundle_id=endpoint_config.bundle_id, + labels=endpoint_config.labels, + streaming_storage_gateway=FirehoseStreamingStorageGateway(), + ) + except Exception: + handler = None return StreamingForwarder( predict_endpoint=pred, - llm_engine_unwrap=self.llm_engine_unwrap, + model_engine_unwrap=self.model_engine_unwrap, serialize_results_as_string=serialize_results_as_string, post_inference_hooks_handler=handler, ) + + +def load_named_config(config_uri, config_overrides=None): + with open(config_uri, "rt") as rt: + if config_uri.endswith(".json"): + return json.load(rt) + else: + c = yaml.safe_load(rt) + if config_overrides: + _substitute_config_overrides(c, config_overrides) + if len(c) == 1: + name = list(c.keys())[0] + c = c[name] + if "name" not in c: + c["name"] = name + return c + + +def _substitute_config_overrides(config: dict, config_overrides: List[str]) -> None: + """ + Modifies config based on config_overrides. + + config_overrides should be a list of strings of the form `key=value`, + where `key` can be of the form `key1.key2` to denote a substitution for config[key1][key2] + (nesting can be arbitrarily deep). + """ + for override in config_overrides: + split = override.split("=") + if len(split) != 2: + raise ValueError(f"Config override {override} must contain exactly one =") + key_path, value = split + try: + _set_value(config, key_path.split("."), value) + except Exception as e: + raise ValueError(f"Error setting {key_path} to {value} in {config}") from e + + +def _cast_value(value: Any) -> Any: + if value.isdigit(): + return int(value) + elif value.startswith("[") and value.endswith("]"): + # Can't use json because it doesn't support single quotes + return ast.literal_eval(value) + else: + return value + + +def _set_value(config: dict, key_path: List[str], value: Any) -> None: + """ + Modifies config by setting the value at config[key_path[0]][key_path[1]]... to be `value`. + """ + key = key_path[0] + if len(key_path) == 1: + config[key] = _cast_value(value) + else: + if key not in config: + config[key] = dict() + _set_value(config[key], key_path[1:], value) diff --git a/model-engine/model_engine_server/inference/forwarding/http_forwarder.py b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py new file mode 100644 index 00000000..89fcb3fb --- /dev/null +++ b/model-engine/model_engine_server/inference/forwarding/http_forwarder.py @@ -0,0 +1,263 @@ +import argparse +import asyncio +import os +import signal +from functools import lru_cache +from typing import Any, Dict, Optional + +import orjson +import uvicorn +from fastapi import BackgroundTasks, Depends, FastAPI +from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.inference.forwarding.forwarding import ( + Forwarder, + LoadForwarder, + LoadStreamingForwarder, + StreamingForwarder, + load_named_config, +) +from sse_starlette import EventSourceResponse + +logger = make_logger(logger_name()) + + +def get_config(): + overrides = os.getenv("CONFIG_OVERRIDES") + config_overrides = None + if overrides is not None: + config_overrides = overrides.split(";") + return load_named_config( + os.getenv("CONFIG_FILE"), + config_overrides, + ) + + +def get_forwarder_loader(destination_path: Optional[str] = None) -> LoadForwarder: + config = get_config()["sync"] + if "extra_routes" in config: + del config["extra_routes"] + if destination_path: + config["predict_route"] = destination_path + forwarder_loader = LoadForwarder(**config) + return forwarder_loader + + +def get_streaming_forwarder_loader( + destination_path: Optional[str] = None, +) -> LoadStreamingForwarder: + config = get_config()["stream"] + if "extra_routes" in config: + del config["extra_routes"] + if destination_path: + config["predict_route"] = destination_path + streaming_forwarder_loader = LoadStreamingForwarder(**config) + return streaming_forwarder_loader + + +@lru_cache() +def get_concurrency_limiter() -> MultiprocessingConcurrencyLimiter: + config = get_config() + concurrency = int(config.get("max_concurrency", 100)) + return MultiprocessingConcurrencyLimiter( + concurrency=concurrency, fail_on_concurrency_limit=True + ) + + +@lru_cache() +def load_forwarder(destination_path: Optional[str] = None) -> Forwarder: + return get_forwarder_loader(destination_path).load(None, None) + + +@lru_cache() +def load_streaming_forwarder(destination_path: Optional[str] = None) -> StreamingForwarder: + return get_streaming_forwarder_loader(destination_path).load(None, None) + + +async def predict( + request: EndpointPredictV1Request, + background_tasks: BackgroundTasks, + forwarder: Forwarder = Depends(load_forwarder), + limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter), +): + with limiter: + try: + response = await forwarder.forward(request.model_dump()) + if forwarder.post_inference_hooks_handler: + background_tasks.add_task( + forwarder.post_inference_hooks_handler.handle, request, response + ) + return response + except Exception: + logger.error(f"Failed to decode payload from: {request}") + raise + + +async def stream( + request: EndpointPredictV1Request, + forwarder: StreamingForwarder = Depends(load_streaming_forwarder), + limiter: MultiprocessingConcurrencyLimiter = Depends(get_concurrency_limiter), +): + with limiter: + try: + payload = request.model_dump() + except Exception: + logger.error(f"Failed to decode payload from: {request}") + raise + else: + logger.debug(f"Received request: {payload}") + + responses = forwarder.forward(payload) + # We fetch the first response to check if upstream request was successful + # If it was not, this will raise the corresponding HTTPException + # If it was, we will proceed to the event generator + initial_response = await responses.__anext__() + + async def event_generator(): + yield {"data": orjson.dumps(initial_response).decode("utf-8")} + + async for response in responses: + yield {"data": orjson.dumps(response).decode("utf-8")} + + return EventSourceResponse(event_generator()) + + +async def serve_http(app: FastAPI, **uvicorn_kwargs: Any): # pragma: no cover + logger.info("Available routes are:") + for route in app.routes: + methods = getattr(route, "methods", None) + path = getattr(route, "path", None) + + if methods is None or path is None: + continue + + logger.info("Route: %s, Methods: %s", path, ", ".join(methods)) + + config = uvicorn.Config(app, **uvicorn_kwargs) + server = uvicorn.Server(config) + + loop = asyncio.get_running_loop() + + server_task = loop.create_task(server.serve()) + + def signal_handler() -> None: + # prevents the uvicorn signal handler to exit early + server_task.cancel() + + async def dummy_shutdown() -> None: + pass + + loop.add_signal_handler(signal.SIGINT, signal_handler) + loop.add_signal_handler(signal.SIGTERM, signal_handler) + + try: + await server_task + return dummy_shutdown() + except asyncio.CancelledError: + logger.info("Gracefully stopping http server") + return server.shutdown() + + +async def run_server(args, **uvicorn_kwargs) -> None: # pragma: no cover + app = await init_app() + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + **uvicorn_kwargs, + ) + + await shutdown_task + + +async def init_app(): + app = FastAPI() + + def healthcheck(): + return "OK" + + def add_extra_routes(app: FastAPI): + """Read extra_routes from config and dynamically add routes to app""" + config = get_config() + sync_forwarders: Dict[str, Forwarder] = dict() + stream_forwarders: Dict[str, StreamingForwarder] = dict() + for route in config.get("sync", {}).get("extra_routes", []): + sync_forwarders[route] = load_forwarder(route) + for route in config.get("stream", {}).get("extra_routes", []): + stream_forwarders[route] = load_streaming_forwarder(route) + + all_routes = set(list(sync_forwarders.keys()) + list(stream_forwarders.keys())) + + for route in all_routes: + + def get_sync_forwarder(route=route): + return sync_forwarders.get(route) + + def get_stream_forwarder(route=route): + return stream_forwarders.get(route) + + # This route is a catch-all for any requests that don't match the /predict or /stream routes + # It will treat the request as a streaming request if the "stream" body parameter is set to true + # NOTE: it is important for this to be defined AFTER the /predict and /stream endpoints + # because FastAPI will match the first route that matches the request path + async def predict_or_stream( + request: EndpointPredictV1Request, + background_tasks: BackgroundTasks, + sync_forwarder: Forwarder = Depends(get_sync_forwarder), + stream_forwarder: StreamingForwarder = Depends(get_stream_forwarder), + limiter=Depends(get_concurrency_limiter), + ): + if not request.args: + raise Exception("Request has no args") + if request.args.root.get("stream", False) and stream_forwarder: + return await stream(request, stream_forwarder, limiter) + elif request.args.root.get("stream") is not True and sync_forwarder: + return await predict(request, background_tasks, sync_forwarder, limiter) + else: + raise Exception("No forwarder configured for this route") + + logger.info(f"Adding route {route}") + app.add_api_route( + path=route, + endpoint=predict_or_stream, + methods=["POST"], + ) + + app.add_api_route(path="/healthz", endpoint=healthcheck, methods=["GET"]) + app.add_api_route(path="/readyz", endpoint=healthcheck, methods=["GET"]) + app.add_api_route(path="/predict", endpoint=predict, methods=["POST"]) + app.add_api_route(path="/stream", endpoint=stream, methods=["POST"]) + + add_extra_routes(app) + return app + + +def entrypoint(): # pragma: no cover + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + parser.add_argument("--num-workers", type=int, required=True) + parser.add_argument("--host", type=str, default=None) + parser.add_argument("--port", type=int, default=5000) + parser.add_argument("--set", type=str, action="append") + parser.add_argument("--graceful-timeout", type=int, default=600) + + args, extra_args = parser.parse_known_args() + + os.environ["CONFIG_FILE"] = args.config + if args.set is not None: + os.environ["CONFIG_OVERRIDES"] = ";".join(args.set) + + asyncio.run( + run_server( + args, + timeout_keep_alive=2, + timeout_graceful_shutdown=args.graceful_timeout, + workers=args.num_workers, + *extra_args, + ) + ) + + +if __name__ == "__main__": + entrypoint() diff --git a/server/llm_engine_server/infra/gateways/resources/__init__.py b/model-engine/model_engine_server/inference/infra/__init__.py similarity index 100% rename from server/llm_engine_server/infra/gateways/resources/__init__.py rename to model-engine/model_engine_server/inference/infra/__init__.py diff --git a/server/llm_engine_server/scripts/__init__.py b/model-engine/model_engine_server/inference/infra/gateways/__init__.py similarity index 100% rename from server/llm_engine_server/scripts/__init__.py rename to model-engine/model_engine_server/inference/infra/gateways/__init__.py diff --git a/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py new file mode 100644 index 00000000..30aca62b --- /dev/null +++ b/model-engine/model_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py @@ -0,0 +1,53 @@ +from datadog import statsd +from model_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( + InferenceMonitoringMetricsGateway, +) + + +class DatadogInferenceMonitoringMetricsGateway(InferenceMonitoringMetricsGateway): + def emit_attempted_post_inference_hook(self, hook: str): + statsd.increment(f"scale_launch.post_inference_hook.{hook}.attempt") + + def emit_successful_post_inference_hook(self, hook: str): + statsd.increment(f"scale_launch.post_inference_hook.{hook}.success") + + def emit_async_task_received_metric(self, queue_name: str): + statsd.increment( + "scale_launch.async_task.received.count", tags=[f"queue_name:{queue_name}"] + ) # pragma: no cover + + def emit_async_task_stuck_metric(self, queue_name: str): + statsd.increment("scale_launch.async_task.stuck.count", tags=[f"queue_name:{queue_name}"]) + + def emit_batch_completions_metric( + self, + model: str, + use_tool: bool, + num_prompt_tokens: int, + num_completion_tokens: int, + is_finetuned: bool, + ): + tags = [ + f"model:{model}", + f"use_tool:{use_tool}", + f"is_finetuned:{is_finetuned}", + ] + statsd.increment( + "model_engine.batch_inference.vllm.generation_count", + tags=tags, + ) + statsd.increment( + "model_engine.batch_inference.vllm.token_count.total", + num_prompt_tokens + num_completion_tokens, + tags=tags, + ) + statsd.increment( + "model_engine.batch_inference.vllm.token_count.completion", + num_completion_tokens, + tags=tags, + ) + statsd.increment( + "model_engine.batch_inference.vllm.token_count.prompt", + num_prompt_tokens, + tags=tags, + ) diff --git a/model-engine/model_engine_server/inference/infra/gateways/fake_usage_metrics_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/fake_usage_metrics_gateway.py new file mode 100644 index 00000000..d3e76fdf --- /dev/null +++ b/model-engine/model_engine_server/inference/infra/gateways/fake_usage_metrics_gateway.py @@ -0,0 +1,10 @@ +from typing import Dict + +from model_engine_server.inference.domain.gateways.usage_metrics_gateway import UsageMetricsGateway + + +class FakeUsageMetricsGateway(UsageMetricsGateway): + """No-op usage metrics emitter""" + + def emit_task_call_metric(self, idempotency_token: str, tags: Dict[str, str]): + pass diff --git a/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py b/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py new file mode 100644 index 00000000..9ec93f44 --- /dev/null +++ b/model-engine/model_engine_server/inference/infra/gateways/firehose_streaming_storage_gateway.py @@ -0,0 +1,65 @@ +import json +from typing import Any, Dict + +import boto3 +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import StreamPutException +from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( + StreamingStorageGateway, +) + +logger = make_logger(logger_name()) + + +class FirehoseStreamingStorageGateway(StreamingStorageGateway): + """ + A gateway that stores data through the AWS Kinesis Firehose streaming mechanism. + """ + + def __init__(self): + pass + + """ + Creates a new firehose client. + + Streams with Snowflake as a destination and the AWS profile live in different + accounts. Firehose doesn't support resource-based policies, so we need to assume + a new role to write to the stream. + """ + + def _get_firehose_client(self): + sts_session = boto3.Session(region_name=infra_config().default_region) + sts_client = sts_session.client("sts") + assumed_role_object = sts_client.assume_role( + RoleArn=infra_config().firehose_role_arn, + RoleSessionName="AssumeMlLoggingRoleSession", + ) + credentials = assumed_role_object["Credentials"] + session = boto3.Session( + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + firehose_client = session.client("firehose", region_name=infra_config().default_region) + return firehose_client + + def put_record(self, stream_name: str, record: Dict[str, Any]) -> Dict[str, Any]: + """ + Put a record into a Firehose stream. + + Args: + stream_name: The name of the stream. + record: The record to put into the stream. + """ + firehose_response = self._get_firehose_client().put_record( + DeliveryStreamName=stream_name, Record={"Data": json.dumps(record).encode("utf-8")} + ) + if firehose_response["ResponseMetadata"]["HTTPStatusCode"] != 200: + raise StreamPutException( + f"Failed to put record into firehose stream {stream_name}. Response metadata {firehose_response['ResponseMetadata']}." + ) + logger.info( + f"Logged to firehose stream {stream_name}. Record ID: {firehose_response['RecordId']}. Task ID: {record['RESPONSE_BODY']['task_id']}" + ) + return firehose_response diff --git a/server/llm_engine_server/inference/inject_bundle.Dockerfile b/model-engine/model_engine_server/inference/inject_bundle.Dockerfile similarity index 79% rename from server/llm_engine_server/inference/inject_bundle.Dockerfile rename to model-engine/model_engine_server/inference/inject_bundle.Dockerfile index 94432467..84de0bbc 100644 --- a/server/llm_engine_server/inference/inject_bundle.Dockerfile +++ b/model-engine/model_engine_server/inference/inject_bundle.Dockerfile @@ -13,6 +13,6 @@ WORKDIR /app COPY ${LOCAL_BUNDLE_PATH} ${LOCAL_BUNDLE_PATH} -RUN python /app/llm_engine/llm_engine/inference/download_and_inject_bundle.py +RUN python /app/model-engine/model_engine_server/inference/download_and_inject_bundle.py ENV PYTHONPATH /app \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/limits.conf b/model-engine/model_engine_server/inference/limits.conf new file mode 100644 index 00000000..080aa4a3 --- /dev/null +++ b/model-engine/model_engine_server/inference/limits.conf @@ -0,0 +1,2 @@ +modelengine hard nproc 2000 +modelengine soft nproc 1000 diff --git a/model-engine/model_engine_server/inference/post_inference_hooks.py b/model-engine/model_engine_server/inference/post_inference_hooks.py new file mode 100644 index 00000000..7f0e2f4d --- /dev/null +++ b/model-engine/model_engine_server/inference/post_inference_hooks.py @@ -0,0 +1,235 @@ +import json +from abc import ABC, abstractmethod +from datetime import datetime +from typing import Any, Dict, List, Optional, Union + +import pytz +import requests +from fastapi.responses import JSONResponse +from model_engine_server.common.constants import ( + CALLBACK_POST_INFERENCE_HOOK, + LOGGING_POST_INFERENCE_HOOK, +) +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth +from model_engine_server.domain.entities.model_endpoint_entity import ModelEndpointType +from model_engine_server.domain.exceptions import StreamPutException +from model_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( + InferenceMonitoringMetricsGateway, +) +from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( + StreamingStorageGateway, +) +from tenacity import Retrying, stop_after_attempt, wait_exponential + +logger = make_logger(logger_name()) + + +class PostInferenceHook(ABC): + def __init__( + self, + endpoint_name: str, + bundle_name: str, + user_id: str, + ): + self._endpoint_name = endpoint_name + self._bundle_name = bundle_name + self._user_id = user_id + + @abstractmethod + def handle( + self, + request_payload: EndpointPredictV1Request, + response: Dict[str, Any], + task_id: Optional[str], + ): + pass + + +class CallbackHook(PostInferenceHook): + def __init__( + self, + endpoint_name: str, + bundle_name: str, + user_id: str, + default_callback_url: Optional[str], + default_callback_auth: Optional[CallbackAuth], + ): + super().__init__(endpoint_name, bundle_name, user_id) + self._default_callback_url = default_callback_url + self._default_callback_auth = default_callback_auth + + def handle( + self, + request_payload: EndpointPredictV1Request, + response: Dict[str, Any], + task_id: Optional[str], + ): + callback_url = request_payload.callback_url + if not callback_url: + callback_url = self._default_callback_url + if not callback_url: + logger.warning("No callback URL specified for request.") + return + + response["task_id"] = task_id + auth = request_payload.callback_auth or self._default_callback_auth + if auth and isinstance(auth.root, CallbackBasicAuth): + auth_tuple = (auth.root.username, auth.root.password) + else: + auth_tuple = (self._user_id, "") + + for attempt in Retrying(stop=stop_after_attempt(3), wait=wait_exponential()): + with attempt: + res = requests.post(url=callback_url, json=response, auth=auth_tuple) + assert 200 <= res.status_code < 300 + + +class LoggingHook(PostInferenceHook): + def __init__( + self, + endpoint_name: str, + bundle_name: str, + user_id: str, + endpoint_id: Optional[str], + endpoint_type: Optional[ModelEndpointType], + bundle_id: Optional[str], + labels: Optional[Dict[str, str]], + streaming_storage_gateway: StreamingStorageGateway, + ): + super().__init__(endpoint_name, bundle_name, user_id) + self._endpoint_id = endpoint_id + self._endpoint_type = endpoint_type + self._bundle_id = bundle_id + self._labels = labels + self._streaming_storage_gateway = streaming_storage_gateway + + def handle( + self, + request_payload: EndpointPredictV1Request, + response: Dict[str, Any], + task_id: Optional[str], + ): + if ( + not self._endpoint_id + or not self._endpoint_type + or not self._bundle_id + or not self._labels + ): + logger.warning( + "No endpoint_id, endpoint_type, bundle_id, or labels specified for request." + ) + return + response["task_id"] = task_id + data_record = { + "EMITTED_AT": datetime.now(pytz.timezone("UTC")).strftime("%Y-%m-%dT%H:%M:%S"), + "REQUEST_BODY": request_payload.json(), + "RESPONSE_BODY": response, + "ENDPOINT_ID": self._endpoint_id, + "ENDPOINT_NAME": self._endpoint_name, + "ENDPOINT_TYPE": self._endpoint_type.value, + "BUNDLE_ID": self._bundle_id, + "LABELS": self._labels, + } + try: # pragma: no cover + json_string = json.dumps(data_record) # pragma: no cover + # Check for unexpected double quotes or escape characters + import re # pragma: no cover + + pattern = r'\\[ntrbfv\'"]|["\']' # pragma: no cover + matches = re.findall(pattern, repr(json_string)) # pragma: no cover + if matches: # pragma: no cover + logger.info( # pragma: no cover + "The JSON string contains double quotes or escape characters.", + extra={"json_string": json_string, "matches": matches}, + ) + else: + logger.info("The JSON string is valid.") # pragma: no cover + except (TypeError, ValueError) as e: # pragma: no cover + logger.warning( + f"Error: The data_record object is not a valid JSON object. {e}" + ) # pragma: no cover + + stream_name = infra_config().firehose_stream_name + if stream_name is None: + logger.warning("No firehose stream name specified. Logging hook will not be executed.") + return + streaming_storage_response = {} # pragma: no cover + try: + streaming_storage_response = ( + self._streaming_storage_gateway.put_record( # pragma: no cover + stream_name=stream_name, record=data_record + ) + ) + except StreamPutException: # pragma: no cover + logger.error( # pragma: no cover + f"Failed to put record into firehose stream {stream_name}. Response metadata {streaming_storage_response.get('ResponseMetadata')}." + ) + + +class PostInferenceHooksHandler: + def __init__( + self, + endpoint_name: str, + bundle_name: str, + user_id: str, + billing_queue: str, + billing_tags: Dict[str, Any], + default_callback_url: Optional[str], + default_callback_auth: Optional[CallbackAuth], + post_inference_hooks: Optional[List[str]], + monitoring_metrics_gateway: InferenceMonitoringMetricsGateway, + endpoint_id: Optional[str], + endpoint_type: Optional[ModelEndpointType], + bundle_id: Optional[str], + labels: Optional[Dict[str, str]], + streaming_storage_gateway: StreamingStorageGateway, + ): + self._monitoring_metrics_gateway = monitoring_metrics_gateway + self._hooks: Dict[str, PostInferenceHook] = {} + if post_inference_hooks: + for hook in post_inference_hooks: + # TODO: Ensure that this process gracefully handles errors in + # initializing each post-inference hook. + hook_lower = hook.lower() + if hook_lower == CALLBACK_POST_INFERENCE_HOOK: + self._hooks[hook_lower] = CallbackHook( + endpoint_name, + bundle_name, + user_id, + default_callback_url, + default_callback_auth, + ) + elif hook_lower == LOGGING_POST_INFERENCE_HOOK: + self._hooks[hook_lower] = LoggingHook( + endpoint_name, + bundle_name, + user_id, + endpoint_id, + endpoint_type, + bundle_id, + labels, + streaming_storage_gateway, + ) + else: + raise ValueError(f"Hook {hook_lower} is currently not supported.") + + def handle( + self, + request_payload: EndpointPredictV1Request, + response: Union[Dict[str, Any], JSONResponse], + task_id: Optional[str] = None, + ): + if isinstance(response, JSONResponse): + loaded_response = json.loads(response.body) + else: + loaded_response = response + for hook_name, hook in self._hooks.items(): + self._monitoring_metrics_gateway.emit_attempted_post_inference_hook(hook_name) + try: + hook.handle(request_payload, loaded_response, task_id) # pragma: no cover + self._monitoring_metrics_gateway.emit_successful_post_inference_hook(hook_name) + except Exception: + logger.exception(f"Hook {hook_name} failed.") diff --git a/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile new file mode 100644 index 00000000..459fdaee --- /dev/null +++ b/model-engine/model_engine_server/inference/pytorch_or_tf.base.Dockerfile @@ -0,0 +1,69 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +WORKDIR /app + +# Install basic packages. +# TODO: ffmpeg, libsm6, and lixext6 are essentially hardcoded from lidar. +# It's probably more correct to add support for arbitrary user-specified base images, +# otherwise this base image gets bloated over time. +RUN apt-get update && DEBIAN_FRONTEND=noninteractive apt-get install -y \ + apt-utils \ + dumb-init \ + git \ + ssh \ + emacs-nox \ + htop \ + iftop \ + vim \ + ffmpeg \ + libsm6 \ + libxext6 \ + libcurl4-openssl-dev \ + libssl-dev \ + python3-dev \ + gcc \ + build-essential \ + && rm -rf /var/lib/apt/lists/* + +# Apparently wget has a vulnerability so we remove it here +RUN dpkg -l | grep wget && apt-get remove wget -y || echo "wget not installed, skipping removal" + +# Create a virtualenv for python so we install our packages in the right place +# Not sure how useful the existing contents of the pytorch image are anymore :/ Maybe it's used for cuda/cudnn installs +RUN python3 -m venv /venv +ENV PATH=/venv/bin:$PATH + +# Run everything as not-root user +RUN useradd -m modelengine -s /bin/bash +RUN chown -R modelengine /venv +RUN chown -R modelengine /app +# Limits for nproc and consequently number of files open +ADD model-engine/model_engine_server/inference/limits.conf /etc/security/limits.conf +USER modelengine + +# Not good for layer caching oh well +# The inference code should only need these few files/directories to function (hopefully) +# Don't copy the entire folder for security reasons + +RUN mkdir -p /app/model-engine +RUN mkdir -p /app/model-engine/model_engine_server + +RUN chown -R modelengine /app/model-engine + +COPY --chown=modelengine \ + model-engine/model_engine_server/inference/requirements_base.txt \ + /app/model-engine/model_engine_server/inference/requirements_base.txt +RUN pip install -r /app/model-engine/model_engine_server/inference/requirements_base.txt + +COPY --chown=modelengine model-engine/setup.py /app/model-engine/setup.py +COPY --chown=modelengine model-engine/model_engine_server/__init__.py /app/model-engine/model_engine_server/__init__.py +COPY --chown=modelengine model-engine/model_engine_server/common /app/model-engine/model_engine_server/common +COPY --chown=modelengine model-engine/model_engine_server/core /app/model-engine/model_engine_server/core +COPY --chown=modelengine model-engine/model_engine_server/domain /app/model-engine/model_engine_server/domain +COPY --chown=modelengine model-engine/model_engine_server/infra /app/model-engine/model_engine_server/infra +COPY --chown=modelengine model-engine/model_engine_server/inference /app/model-engine/model_engine_server/inference +WORKDIR /app/model-engine +RUN pip install -e . + +WORKDIR /app diff --git a/model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile b/model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile new file mode 100644 index 00000000..29f5cd81 --- /dev/null +++ b/model-engine/model_engine_server/inference/pytorch_or_tf.user.Dockerfile @@ -0,0 +1,10 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ARG REQUIREMENTS_FILE +COPY --chown=modelengine ${REQUIREMENTS_FILE} /app/model-engine/model_engine_server/inference/requirements.txt +RUN --mount=type=secret,id=codeartifact-pip-conf,target=/etc/pip.conf,mode=0444 \ + PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf \ + pip install -r /app/model-engine/model_engine_server/inference/requirements.txt + +ENV PYTHONPATH /app diff --git a/model-engine/model_engine_server/inference/requirements_base.txt b/model-engine/model_engine_server/inference/requirements_base.txt new file mode 100644 index 00000000..972a5247 --- /dev/null +++ b/model-engine/model_engine_server/inference/requirements_base.txt @@ -0,0 +1,25 @@ +aioredis~=2.0 +urllib3~=1.26.13 +boto3~=1.34.33 +celery[redis,sqs,tblib]==5.3.1 +datadog-api-client==2.11.0 +datadog~=0.47.0 +fastapi~=0.110.0 +# Incompatibility between celery 5 and python 3.7 because of importlib-metadata 5, so we pin it +importlib-metadata<5.0;python_version<"3.8" +scale-launch>=0.1.0 +smart_open==5.1.0 +typing-extensions>=4.1.1 +uvicorn==0.30.6 +waitress==2.0.0 + +# HACK: at time of adding, these deps are imported by model-engine/model_engine_server files +# add here to to prevent `ModuleNotFoundError` error on container startup, these should be in sync with server reqs +# long term: consider having slimmer deps and seperating inference container deps from server container deps +ddtrace==1.8.3 # required for ddtrace-run entrypoint command as well +json-log-formatter~=0.3 # model_engine_server/core/loggers.py +tenacity>=6.0.0,<=6.2.0 # model_engine_server/core/loggers.py +tqdm~=4.64 # model_engine_server/common/service_requests.py +gunicorn~=20.0 +pydantic==2.8.2 + diff --git a/server/llm_engine_server/inference/service_requests.py b/model-engine/model_engine_server/inference/service_requests.py similarity index 88% rename from server/llm_engine_server/inference/service_requests.py rename to model-engine/model_engine_server/inference/service_requests.py index 3dd4485e..795ff91f 100644 --- a/server/llm_engine_server/inference/service_requests.py +++ b/model-engine/model_engine_server/inference/service_requests.py @@ -8,14 +8,14 @@ import boto3 import cloudpickle from celery.result import allow_join_result -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from llm_engine_server.common.errors import UpstreamHTTPSvcError -from llm_engine_server.common.io import open_wrapper -from llm_engine_server.common.service_requests import make_sync_request_with_retries -from llm_engine_server.core.celery import TaskVisibility, celery_app -from llm_engine_server.core.loggers import filename_wo_ext, make_logger +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME +from model_engine_server.common.errors import UpstreamHTTPSvcError +from model_engine_server.common.io import open_wrapper +from model_engine_server.common.service_requests import make_sync_request_with_retries +from model_engine_server.core.celery import TaskVisibility, celery_app +from model_engine_server.core.loggers import logger_name, make_logger -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) # TODO now that we're on SQS this won't work, since it connects to redis s3_bucket: str = os.environ.get("CELERY_S3_BUCKET") # type: ignore @@ -48,8 +48,7 @@ def get_s3_client(): def _read_function_to_network_endpoint_info(): # Dictionary format: {servable_id: {remote: true/false, endpoint_type: "sync"/"async", destination: },...} - # destination is either a celery queue name, i.e. llm_engine_server., or the full url for an http request, - # i.e. http://.ml-internal.scale.com/predict. + # destination is either a celery queue name, i.e. launch., or the full url for an http request. details_json = os.getenv("CHILD_FN_INFO") if details_json is None: return None @@ -62,7 +61,7 @@ def _read_function_to_network_endpoint_info(): def make_request(servable_id: str, local_fn: Callable, args: List[Any], kwargs: Dict[str, Any]): # This is the external-facing entrypoint. Reads in details and decides to make a network request or not - # This function gets imported and called by the LLMEngine client. + # This function gets imported and called by the Launch client. current_fn_info = child_fn_info[servable_id] use_remote = current_fn_info["remote"] if use_remote: diff --git a/server/llm_engine_server/service_builder/__init__.py b/model-engine/model_engine_server/inference/sync_inference/__init__.py similarity index 100% rename from server/llm_engine_server/service_builder/__init__.py rename to model-engine/model_engine_server/inference/sync_inference/__init__.py diff --git a/server/llm_engine_server/inference/sync_inference/constants.py b/model-engine/model_engine_server/inference/sync_inference/constants.py similarity index 100% rename from server/llm_engine_server/inference/sync_inference/constants.py rename to model-engine/model_engine_server/inference/sync_inference/constants.py diff --git a/server/llm_engine_server/inference/sync_inference/destination_rule.yaml b/model-engine/model_engine_server/inference/sync_inference/destination_rule.yaml similarity index 100% rename from server/llm_engine_server/inference/sync_inference/destination_rule.yaml rename to model-engine/model_engine_server/inference/sync_inference/destination_rule.yaml diff --git a/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py new file mode 100644 index 00000000..aba74bbe --- /dev/null +++ b/model-engine/model_engine_server/inference/sync_inference/fastapi_server.py @@ -0,0 +1,56 @@ +import traceback +from functools import wraps + +from fastapi import FastAPI, HTTPException, Response, status +from model_engine_server.common.concurrency_limiter import MultiprocessingConcurrencyLimiter +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.inference.common import load_predict_fn_or_cls, run_predict +from model_engine_server.inference.sync_inference.constants import ( + CONCURRENCY, + FAIL_ON_CONCURRENCY_LIMIT, + NAME, +) + +logger = make_logger(logger_name()) + + +def with_concurrency_limit(concurrency_limiter: MultiprocessingConcurrencyLimiter): + def _inner(flask_func): + @wraps(flask_func) + def _inner_2(*args, **kwargs): + with concurrency_limiter: + return flask_func(*args, **kwargs) + + return _inner_2 + + return _inner + + +app = FastAPI(title=NAME) +concurrency_limiter = MultiprocessingConcurrencyLimiter(CONCURRENCY, FAIL_ON_CONCURRENCY_LIMIT) + +# How does this interact with threads? +# Analogous to init_worker() inside async_inference +predict_fn = load_predict_fn_or_cls() + + +@app.get("/healthcheck") +@app.get("/healthz") +@app.get("/readyz") +def healthcheck(): + return Response(status_code=status.HTTP_200_OK) + + +@app.post("/predict") +@with_concurrency_limit(concurrency_limiter) +def predict(payload: EndpointPredictV1Request): + """ + Assumption: payload is a JSON with format {"url": , "args": , "returned_pickled": boolean} + Returns: Results of running the predict function on the request url. See `run_predict`. + """ + try: + result = run_predict(predict_fn, payload) + return result + except Exception: + raise HTTPException(status_code=500, detail=dict(traceback=str(traceback.format_exc()))) diff --git a/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py new file mode 100644 index 00000000..2c93b770 --- /dev/null +++ b/model-engine/model_engine_server/inference/sync_inference/start_fastapi_server.py @@ -0,0 +1,39 @@ +import argparse +import os +import subprocess + +from model_engine_server.inference.common import unset_sensitive_envvars +from model_engine_server.inference.sync_inference.constants import NUM_PROCESSES + +PORT = os.environ["PORT"] + + +def start_server(): + parser = argparse.ArgumentParser() + parser.add_argument("--graceful-timeout", type=int, default=1800) + args, extra_args = parser.parse_known_args() + + # TODO: HTTPS + command = [ + "gunicorn", + "--bind", + f"[::]:{PORT}", + "--timeout", + "1200", + "--keep-alive", + "2", + "--worker-class", + "uvicorn.workers.UvicornWorker", + "--workers", + str(NUM_PROCESSES), + "--graceful-timeout", + str(args.graceful_timeout), + "model_engine_server.inference.sync_inference.fastapi_server:app", + *extra_args, + ] + unset_sensitive_envvars() + subprocess.run(command) + + +if __name__ == "__main__": + start_server() diff --git a/server/llm_engine_server/inference/sync_inference/virtual_service.yaml b/model-engine/model_engine_server/inference/sync_inference/virtual_service.yaml similarity index 100% rename from server/llm_engine_server/inference/sync_inference/virtual_service.yaml rename to model-engine/model_engine_server/inference/sync_inference/virtual_service.yaml diff --git a/server/llm_engine_server/inference/sync_inference/vpa.yaml b/model-engine/model_engine_server/inference/sync_inference/vpa.yaml similarity index 100% rename from server/llm_engine_server/inference/sync_inference/vpa.yaml rename to model-engine/model_engine_server/inference/sync_inference/vpa.yaml diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile b/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile new file mode 100644 index 00000000..be1fa7e9 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/Dockerfile @@ -0,0 +1,12 @@ +FROM nvcr.io/nvidia/tritonserver:24.03-trtllm-python-py3 + +COPY requirements.txt /workspace/requirements.txt +WORKDIR /workspace +RUN pip install -r requirements.txt + +# Install s5cmd +RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz +RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz + +COPY launch_triton_server.py /workspace/launch_triton_server.py +COPY triton_model_repo /workspace/model_repo \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/README.md b/model-engine/model_engine_server/inference/tensorrt-llm/README.md new file mode 100644 index 00000000..0468de7d --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/README.md @@ -0,0 +1,14 @@ +# Preparing the model weights/tokenizers + +Our TensorRT-LLM docker image expects weights to live in s3/other blob store with the following directory structure: + +root/ + model_tokenizer/ + + model_weights/ + config.json + rank.engine + +You can obtain `model_weights` by building a TRT-LLM engine via the directions found on Nvidia's site (e.g. https://github.com/NVIDIA/TensorRT-LLM/blob/main/README.md#installation, https://github.com/NVIDIA/TensorRT-LLM/blob/v0.8.0/examples/llama/convert_checkpoint.py) + +The inference image is built via the Dockerfile in the same directory as this readme. \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py b/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py new file mode 100644 index 00000000..1a3434ee --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/launch_triton_server.py @@ -0,0 +1,41 @@ +import argparse +import subprocess +from pathlib import Path + + +def parse_arguments(): + parser = argparse.ArgumentParser() + parser.add_argument( + "--world_size", type=int, default=1, help="world size, only support tensor parallelism now" + ) + parser.add_argument("--tritonserver", type=str, default="/opt/tritonserver/bin/tritonserver") + parser.add_argument( + "--http-address", + type=str, + default="ipv6:[::1]", + help="Default HTTP address to ipv6:[::1].", + ) + parser.add_argument( + "--http-port", + type=int, + default=5005, + help="Default HTTP port to 5005. See llm-engine/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml", + ) + path = str(Path(__file__).parent.absolute()) + "/../all_models/gpt" + parser.add_argument("--model_repo", type=str, default=path) + return parser.parse_args() + + +def get_cmd(world_size, tritonserver, model_repo, http_address, http_port): + cmd = "mpirun --allow-run-as-root " + for i in range(world_size): + cmd += f" -n 1 {tritonserver} --model-repository={model_repo} --http-address {http_address} --http-port {http_port} --disable-auto-complete-config --backend-config=python,shm-region-prefix-name=prefix{i}_ : " + return cmd + + +if __name__ == "__main__": + args = parse_arguments() + cmd = get_cmd( + int(args.world_size), args.tritonserver, args.model_repo, args.http_address, args.http_port + ) + subprocess.call(cmd, shell=True) diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt b/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt new file mode 100644 index 00000000..7d75f3fc --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/requirements.txt @@ -0,0 +1,3 @@ +sentencepiece==0.1.99 +protobuf==4.24.4 +torch==2.2.2 \ No newline at end of file diff --git a/server/tests/__init__.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/1/.tmp similarity index 100% rename from server/tests/__init__.py rename to model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/1/.tmp diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt new file mode 100644 index 00000000..55a52eaf --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/ensemble/config.pbtxt @@ -0,0 +1,444 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "ensemble" +platform: "ensemble" +max_batch_size: 128 +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "embedding_bias_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "embedding_bias_weights" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] +ensemble_scheduling { + step [ + { + model_name: "preprocessing" + model_version: -1 + input_map { + key: "QUERY" + value: "text_input" + } + input_map { + key: "REQUEST_OUTPUT_LEN" + value: "max_tokens" + } + input_map { + key: "BAD_WORDS_DICT" + value: "bad_words" + } + input_map { + key: "STOP_WORDS_DICT" + value: "stop_words" + } + input_map { + key: "EMBEDDING_BIAS_WORDS" + value: "embedding_bias_words" + } + input_map { + key: "EMBEDDING_BIAS_WEIGHTS" + value: "embedding_bias_weights" + } + input_map { + key: "END_ID" + value: "end_id" + } + input_map { + key: "PAD_ID" + value: "pad_id" + } + output_map { + key: "REQUEST_INPUT_LEN" + value: "_REQUEST_INPUT_LEN" + } + output_map { + key: "INPUT_ID" + value: "_INPUT_ID" + } + output_map { + key: "REQUEST_OUTPUT_LEN" + value: "_REQUEST_OUTPUT_LEN" + } + output_map { + key: "STOP_WORDS_IDS" + value: "_STOP_WORDS_IDS" + } + output_map { + key: "BAD_WORDS_IDS" + value: "_BAD_WORDS_IDS" + } + output_map { + key: "EMBEDDING_BIAS" + value: "_EMBEDDING_BIAS" + } + output_map { + key: "OUT_END_ID" + value: "_PREPROCESSOR_END_ID" + } + output_map { + key: "OUT_PAD_ID" + value: "_PREPROCESSOR_PAD_ID" + } + }, + { + model_name: "tensorrt_llm" + model_version: -1 + input_map { + key: "input_ids" + value: "_INPUT_ID" + } + input_map { + key: "input_lengths" + value: "_REQUEST_INPUT_LEN" + } + input_map { + key: "request_output_len" + value: "_REQUEST_OUTPUT_LEN" + } + input_map { + key: "end_id" + value: "_PREPROCESSOR_END_ID" + } + input_map { + key: "pad_id" + value: "_PREPROCESSOR_PAD_ID" + } + input_map { + key: "embedding_bias" + value: "_EMBEDDING_BIAS" + } + input_map { + key: "runtime_top_k" + value: "top_k" + } + input_map { + key: "runtime_top_p" + value: "top_p" + } + input_map { + key: "temperature" + value: "temperature" + } + input_map { + key: "len_penalty" + value: "length_penalty" + } + input_map { + key: "repetition_penalty" + value: "repetition_penalty" + } + input_map { + key: "min_length" + value: "min_length" + } + input_map { + key: "presence_penalty" + value: "presence_penalty" + } + input_map { + key: "frequency_penalty" + value: "frequency_penalty" + } + input_map { + key: "random_seed" + value: "random_seed" + } + input_map { + key: "return_log_probs" + value: "return_log_probs" + } + input_map { + key: "return_context_logits" + value: "return_context_logits" + } + input_map { + key: "return_generation_logits" + value: "return_generation_logits" + } + input_map { + key: "beam_width" + value: "beam_width" + } + input_map { + key: "streaming" + value: "stream" + } + input_map { + key: "prompt_embedding_table" + value: "prompt_embedding_table" + } + input_map { + key: "prompt_vocab_size" + value: "prompt_vocab_size" + } + input_map { + key: "stop_words_list" + value: "_STOP_WORDS_IDS" + } + input_map { + key: "bad_words_list" + value: "_BAD_WORDS_IDS" + } + output_map { + key: "output_ids" + value: "_TOKENS_BATCH" + } + output_map { + key: "sequence_length" + value: "_SEQUENCE_LENGTH" + }, + output_map { + key: "cum_log_probs" + value: "_CUM_LOG_PROBS" + } + output_map { + key: "output_log_probs" + value: "_OUTPUT_LOG_PROBS" + }, + output_map { + key: "context_logits" + value: "_CONTEXT_LOGITS" + }, + output_map { + key: "generation_logits" + value: "_GENERATION_LOGITS" + } + }, + { + model_name: "postprocessing" + model_version: -1 + input_map { + key: "TOKENS_BATCH" + value: "_TOKENS_BATCH" + } + input_map { + key: "CUM_LOG_PROBS" + value: "_CUM_LOG_PROBS" + } + input_map { + key: "OUTPUT_LOG_PROBS" + value: "_OUTPUT_LOG_PROBS" + } + input_map { + key: "CONTEXT_LOGITS" + value: "_CONTEXT_LOGITS" + } + input_map { + key: "GENERATION_LOGITS" + value: "_GENERATION_LOGITS" + } + input_map { + key: "SEQUENCE_LENGTH" + value: "_SEQUENCE_LENGTH" + } + output_map { + key: "OUTPUT" + value: "text_output" + } + output_map { + key: "OUT_OUTPUT_LOG_PROBS" + value: "output_log_probs" + } + output_map { + key: "OUT_CUM_LOG_PROBS" + value: "cum_log_probs" + } + output_map { + key: "OUT_CONTEXT_LOGITS" + value: "context_logits" + } + output_map { + key: "OUT_GENERATION_LOGITS" + value: "generation_logits" + } + } + ] +} diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py new file mode 100644 index 00000000..c1c6353b --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/1/model.py @@ -0,0 +1,201 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import SPIECE_UNDERLINE, AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + self.skip_special_tokens = model_config["parameters"].get( + "skip_special_tokens", {"string_value": "true"} + )["string_value"].lower() in ["true", "1", "t", "y", "yes"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, padding_side="left", trust_remote_code=True + ) + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + # Parse model output configs + output_config = pb_utils.get_output_config_by_name(model_config, "OUTPUT") + + # Convert Triton types to numpy types + self.output_dtype = pb_utils.triton_string_to_numpy(output_config["data_type"]) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + for idx, request in enumerate(requests): + # Get input tensors + tokens_batch = pb_utils.get_input_tensor_by_name(request, "TOKENS_BATCH").as_numpy() + + # Get sequence length + sequence_lengths = pb_utils.get_input_tensor_by_name( + request, "SEQUENCE_LENGTH" + ).as_numpy() + + # Get cum log probs + cum_log_probs = pb_utils.get_input_tensor_by_name(request, "CUM_LOG_PROBS").as_numpy() + + # Get sequence length + output_log_probs = pb_utils.get_input_tensor_by_name( + request, "OUTPUT_LOG_PROBS" + ).as_numpy() + + # Get context logits + context_logits = pb_utils.get_input_tensor_by_name(request, "CONTEXT_LOGITS").as_numpy() + + # Get generation logits + generation_logits = pb_utils.get_input_tensor_by_name( + request, "GENERATION_LOGITS" + ).as_numpy() + + # Reshape Input + # tokens_batch = tokens_batch.reshape([-1, tokens_batch.shape[0]]) + # tokens_batch = tokens_batch.T + + # Postprocessing output data. + outputs = self._postprocessing(tokens_batch, sequence_lengths) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + output_tensor = pb_utils.Tensor("OUTPUT", np.array(outputs).astype(self.output_dtype)) + + out_cum_log_probs = pb_utils.Tensor("OUT_CUM_LOG_PROBS", cum_log_probs) + + out_output_log_probs = pb_utils.Tensor("OUT_OUTPUT_LOG_PROBS", output_log_probs) + + out_context_logits = pb_utils.Tensor("OUT_CONTEXT_LOGITS", context_logits) + + out_generation_logits = pb_utils.Tensor("OUT_GENERATION_LOGITS", generation_logits) + + # Create InferenceResponse. You can set an error here in case + # there was a problem with handling this inference request. + # Below is an example of how you can set errors in inference + # response: + # + # pb_utils.InferenceResponse( + # output_tensors=..., TritonError("An error occurred")) + inference_response = pb_utils.InferenceResponse( + output_tensors=[ + output_tensor, + out_cum_log_probs, + out_output_log_probs, + out_context_logits, + out_generation_logits, + ] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") + + def _postprocessing(self, tokens_batch, sequence_lengths): + outputs = [] + for batch_idx, beam_tokens in enumerate(tokens_batch): + for beam_idx, tokens in enumerate(beam_tokens): + seq_len = sequence_lengths[batch_idx][beam_idx] + output = self.tokenizer.decode( + tokens[:seq_len], skip_special_tokens=self.skip_special_tokens + ) + # Adapted from https://github.com/triton-inference-server/tensorrtllm_backend/pull/423 + # This is somewhat of a hack: add a space before the output if the first token starts with a space + # This may add a space in front of the first token though when we don't want it. + if seq_len > 0: + token_id_string = self.tokenizer.convert_ids_to_tokens( + tokens[:1], skip_special_tokens=self.skip_special_tokens + ) + if ( + len(token_id_string) > 0 + and len(token_id_string[0]) > 0 + and token_id_string[0][0] == SPIECE_UNDERLINE + ): + output = " " + output + outputs.append(output.encode("utf8")) + return outputs diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt new file mode 100644 index 00000000..93af4eec --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/postprocessing/config.pbtxt @@ -0,0 +1,118 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "postprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "TOKENS_BATCH" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + }, + { + name: "SEQUENCE_LENGTH" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "CUM_LOG_PROBS" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "OUTPUT_LOG_PROBS" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "CONTEXT_LOGITS" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + optional: true + }, + { + name: "GENERATION_LOGITS" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + optional: true + } +] +output [ + { + name: "OUTPUT" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "OUT_CUM_LOG_PROBS" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "OUT_OUTPUT_LOG_PROBS" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "OUT_CONTEXT_LOGITS" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "OUT_GENERATION_LOGITS" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "model_tokenizer" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "llama" + } +} + +parameters { + key: "skip_special_tokens" + value: { + string_value: "True" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py new file mode 100644 index 00000000..ea2d4789 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/1/model.py @@ -0,0 +1,348 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +from typing import List + +import numpy as np +import triton_python_backend_utils as pb_utils +from transformers import AutoTokenizer, LlamaTokenizer, T5Tokenizer + + +class TritonPythonModel: + """Your Python model must use the same class name. Every Python model + that is created must have "TritonPythonModel" as the class name. + """ + + def initialize(self, args): + """`initialize` is called only once when the model is being loaded. + Implementing `initialize` function is optional. This function allows + the model to initialize any state associated with this model. + Parameters + ---------- + args : dict + Both keys and values are strings. The dictionary keys and values are: + * model_config: A JSON string containing the model configuration + * model_instance_kind: A string containing model instance kind + * model_instance_device_id: A string containing model instance device ID + * model_repository: Model repository path + * model_version: Model version + * model_name: Model name + """ + # Parse model configs + model_config = json.loads(args["model_config"]) + tokenizer_dir = model_config["parameters"]["tokenizer_dir"]["string_value"] + tokenizer_type = model_config["parameters"]["tokenizer_type"]["string_value"] + self.add_special_tokens = model_config["parameters"].get( + "add_special_tokens", {"string_value": "false"} + )["string_value"].lower() in ["true", "1", "t", "y", "yes"] + + if tokenizer_type == "t5": + self.tokenizer = T5Tokenizer(vocab_file=tokenizer_dir, padding_side="left") + elif tokenizer_type == "auto": + self.tokenizer = AutoTokenizer.from_pretrained( + tokenizer_dir, padding_side="left", trust_remote_code=True + ) + elif tokenizer_type == "llama": + self.tokenizer = LlamaTokenizer.from_pretrained( + tokenizer_dir, legacy=False, padding_side="left" + ) + else: + raise AttributeError(f"Unexpected tokenizer type: {tokenizer_type}") + self.tokenizer.pad_token = self.tokenizer.eos_token + + self.tokenizer_end_id = self.tokenizer.encode( + self.tokenizer.eos_token, add_special_tokens=False + )[0] + self.tokenizer_pad_id = self.tokenizer.encode( + self.tokenizer.pad_token, add_special_tokens=False + )[0] + + # Parse model output configs and convert Triton types to numpy types + output_names = [ + "INPUT_ID", + "REQUEST_INPUT_LEN", + "BAD_WORDS_IDS", + "STOP_WORDS_IDS", + "OUT_END_ID", + "OUT_PAD_ID", + ] + input_names = ["EMBEDDING_BIAS_WORDS", "EMBEDDING_BIAS_WEIGHTS"] + for input_name in input_names: + setattr( + self, + input_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_input_config_by_name(model_config, input_name)["data_type"] + ), + ) + + for output_name in output_names: + setattr( + self, + output_name.lower() + "_dtype", + pb_utils.triton_string_to_numpy( + pb_utils.get_output_config_by_name(model_config, output_name)["data_type"] + ), + ) + + def execute(self, requests): + """`execute` must be implemented in every Python model. `execute` + function receives a list of pb_utils.InferenceRequest as the only + argument. This function is called when an inference is requested + for this model. Depending on the batching configuration (e.g. Dynamic + Batching) used, `requests` may contain multiple requests. Every + Python model, must create one pb_utils.InferenceResponse for every + pb_utils.InferenceRequest in `requests`. If there is an error, you can + set the error argument when creating a pb_utils.InferenceResponse. + Parameters + ---------- + requests : list + A list of pb_utils.InferenceRequest + Returns + ------- + list + A list of pb_utils.InferenceResponse. The length of this list must + be the same as `requests` + """ + + responses = [] + + # Every Python backend must iterate over everyone of the requests + # and create a pb_utils.InferenceResponse for each of them. + logger = pb_utils.Logger + for idx, request in enumerate(requests): + # Get input tensors + query = pb_utils.get_input_tensor_by_name(request, "QUERY").as_numpy() + batch_dim = query.shape[0] + if batch_dim != 1: + + err_str = "Inflight batching backend expects requests with batch size of 1." + logger.log_error(err_str) + responses.append( + pb_utils.InferenceResponse( + output_tensors=[], error=pb_utils.TritonError(err_str) + ) + ) + continue + + request_output_len = pb_utils.get_input_tensor_by_name( + request, "REQUEST_OUTPUT_LEN" + ).as_numpy() + + bad_words_dict = pb_utils.get_input_tensor_by_name(request, "BAD_WORDS_DICT") + if bad_words_dict is not None: + bad_words_dict = bad_words_dict.as_numpy() + + stop_words_dict = pb_utils.get_input_tensor_by_name(request, "STOP_WORDS_DICT") + if stop_words_dict is not None: + stop_words_dict = stop_words_dict.as_numpy() + + embedding_bias_words = pb_utils.get_input_tensor_by_name( + request, "EMBEDDING_BIAS_WORDS" + ) + if embedding_bias_words is not None: + embedding_bias_words = embedding_bias_words.as_numpy() + + embedding_bias_weights = pb_utils.get_input_tensor_by_name( + request, "EMBEDDING_BIAS_WEIGHTS" + ) + if embedding_bias_weights is not None: + embedding_bias_weights = embedding_bias_weights.as_numpy() + + # Take the end_id from the input tensors + # If not specified, use tokenizer to get end_id + end_id = pb_utils.get_input_tensor_by_name(request, "END_ID") + if end_id is not None: + end_id = end_id.as_numpy() + else: + end_id = [[self.tokenizer_end_id]] + + # Take the pad_id from the input tensors + # If not specified, use tokenizer to get pad_id + pad_id = pb_utils.get_input_tensor_by_name(request, "PAD_ID") + if pad_id is not None: + pad_id = pad_id.as_numpy() + else: + pad_id = [[self.tokenizer_pad_id]] + + # Preprocessing input data. + input_id, request_input_len = self._create_request(query) + bad_words = self._to_word_list_format(bad_words_dict) + stop_words = self._to_word_list_format(stop_words_dict) + + embedding_bias = self._get_embedding_bias( + embedding_bias_words, embedding_bias_weights, self.embedding_bias_weights_dtype + ) + + # Create output tensors. You need pb_utils.Tensor + # objects to create pb_utils.InferenceResponse. + input_id_tensor = pb_utils.Tensor("INPUT_ID", input_id.astype(self.input_id_dtype)) + request_input_len_tensor = pb_utils.Tensor( + "REQUEST_INPUT_LEN", request_input_len.astype(self.request_input_len_dtype) + ) + request_output_len_tensor = pb_utils.Tensor("REQUEST_OUTPUT_LEN", request_output_len) + bad_words_ids_tensor = pb_utils.Tensor("BAD_WORDS_IDS", bad_words) + stop_words_ids_tensor = pb_utils.Tensor("STOP_WORDS_IDS", stop_words) + embedding_bias_tensor = pb_utils.Tensor("EMBEDDING_BIAS", embedding_bias) + end_id_tensor = pb_utils.Tensor("OUT_END_ID", np.array(end_id, dtype=np.int32)) + pad_id_tensor = pb_utils.Tensor("OUT_PAD_ID", np.array(pad_id, dtype=np.int32)) + + inference_response = pb_utils.InferenceResponse( + output_tensors=[ + input_id_tensor, + bad_words_ids_tensor, + stop_words_ids_tensor, + request_input_len_tensor, + request_output_len_tensor, + embedding_bias_tensor, + end_id_tensor, + pad_id_tensor, + ] + ) + responses.append(inference_response) + + # You should return a list of pb_utils.InferenceResponse. Length + # of this list must match the length of `requests` list. + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") + + def _create_request(self, query): + """ + query : batch string (2D numpy array) + """ + start_ids = [ + np.array( + self.tokenizer.encode(s[0].decode(), add_special_tokens=self.add_special_tokens) + ).astype(int) + for s in query + ] + start_lengths = np.array([[len(ids)] for ids in start_ids]).astype(int) + + max_len = 0 + for seq in start_ids: + max_len = max(max_len, seq.shape[0]) + start_ids = np.stack( + [ + np.pad( + seq, + (0, max_len - seq.shape[0]), + "constant", + constant_values=(0, self.tokenizer_pad_id), + ) + for seq in start_ids + ] + ) + + return start_ids, start_lengths + + def _to_word_list_format(self, word_lists: List[List[str | bytes]]): + """ + word_lists format: + len(word_lists) == batch_size + word_lists[i] means the words associated to batch item i. A "word" may actually be any string. Like "lorem" or "lorem ipsum". + """ + assert self.tokenizer is not None, "need to set tokenizer" + + if word_lists is None: + # Return an empty array of shape (1,2,0) + return np.empty([1, 2, 0], dtype="int32") + + flat_ids = [] + offsets = [] + for word_list in word_lists: + item_flat_ids = [] + item_offsets = [] + + for word in word_list: + if isinstance(word, bytes): + word = word.decode() + + ids = self.tokenizer.encode(word, add_special_tokens=False) + if len(ids) == 0: + continue + + item_flat_ids += ids + item_offsets.append(len(ids)) + + # Add a case where ids[0] decodes to empty string, then add another set of ids here + # Unfortunately, we don't have access to the entire sequence of returned response tokens when decoding, + # so we have to do what we can to get a reasonable list of token ids corresponding to a stop sequence. + # True correctness would look like figuring out all the ways of decoding a stop sequence, and then + # adding all of them to this item_flat_ids map. + if len(ids) > 1 and self.tokenizer.decode(ids[0]) == "": + new_ids = ids[1:] + item_flat_ids += new_ids + item_offsets.append(len(new_ids)) + + flat_ids.append(np.array(item_flat_ids)) + offsets.append(np.cumsum(np.array(item_offsets))) + + pad_to = max(1, max(len(ids) for ids in flat_ids)) + + for i, (ids, offs) in enumerate(zip(flat_ids, offsets)): + flat_ids[i] = np.pad(ids, (0, pad_to - len(ids)), constant_values=0) + offsets[i] = np.pad(offs, (0, pad_to - len(offs)), constant_values=-1) + + return np.array([flat_ids, offsets], dtype="int32").transpose((1, 0, 2)) + + def _get_embedding_bias(self, embedding_bias_words, embedding_bias_weights, bias_dtype): + + assert self.tokenizer is not None, "need to set tokenizer" + + if embedding_bias_words is None or embedding_bias_weights is None: + return np.empty([1, 0], dtype=self.embedding_bias_weights_dtype) + + batch_embedding_bias = [] + for words, weights in zip(embedding_bias_words, embedding_bias_weights): + + vocab_size = self.tokenizer.vocab_size + embedding_bias = [0.0] * vocab_size + + assert len(words) == len( + weights + ), "Embedding bias words must have same dimension as embedding bias weights" + + for word, weight in zip(words, weights): + if isinstance(word, bytes): + word = word.decode() + ids = self.tokenizer.encode(word) + + if len(ids) == 0: + continue + + for id in ids: + embedding_bias[id] += weight + + batch_embedding_bias.append(np.array(embedding_bias)) + + return np.array(batch_embedding_bias, dtype=bias_dtype) diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt new file mode 100644 index 00000000..3a77e264 --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/preprocessing/config.pbtxt @@ -0,0 +1,147 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "preprocessing" +backend: "python" +max_batch_size: 128 +input [ + { + name: "QUERY" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "BAD_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "STOP_WORDS_DICT" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "EMBEDDING_BIAS_WORDS" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "EMBEDDING_BIAS_WEIGHTS" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + }, + { + name: "END_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + }, + { + name: "PAD_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + } +] +output [ + { + name: "INPUT_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "REQUEST_INPUT_LEN" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "BAD_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "STOP_WORDS_IDS" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + }, + { + name: "EMBEDDING_BIAS" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "REQUEST_OUTPUT_LEN" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "OUT_END_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "OUT_PAD_ID" + data_type: TYPE_INT32 + dims: [ -1 ] + } +] + +parameters { + key: "tokenizer_dir" + value: { + string_value: "model_tokenizer" + } +} + +parameters { + key: "tokenizer_type" + value: { + string_value: "llama" + } +} + +parameters { + key: "add_special_tokens" + value: { + string_value: "False" + } +} + +instance_group [ + { + count: 1 + kind: KIND_CPU + } +] diff --git a/server/tests/integration/__init__.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/1/.gitkeep similarity index 100% rename from server/tests/integration/__init__.py rename to model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/1/.gitkeep diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt new file mode 100644 index 00000000..f1b466eb --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm/config.pbtxt @@ -0,0 +1,370 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "tensorrt_llm" +backend: "tensorrtllm" +max_batch_size: 128 + +model_transaction_policy { + decoupled: true +} + +dynamic_batching { + preferred_batch_size: [ 128 ] + max_queue_delay_microseconds: 100000 +} + +input [ + { + name: "input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + allow_ragged_batch: true + }, + { + name: "input_lengths" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + }, + { + name: "request_output_len" + data_type: TYPE_INT32 + dims: [ 1 ] + }, + { + name: "draft_input_ids" + data_type: TYPE_INT32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "bad_words_list" + data_type: TYPE_INT32 + dims: [ 2, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "embedding_bias" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "runtime_top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "len_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "stop" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "streaming" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + # weights for a lora adapter shape [ num_lora_modules_layers, D x Hi + Ho x D ] + # where the last dimension holds the in / out adapter weights for the associated module (e.g. attn_qkv) and model layer + # each of the in / out tensors are first flattened and then concatenated together in the format above. + # D=adapter_size (R value), Hi=hidden_size_in, Ho=hidden_size_out. + { + name: "lora_weights" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + allow_ragged_batch: true + }, + # module identifier (same size a first dimension of lora_weights) + # See LoraModule::ModuleType for model id mapping + # + # "attn_qkv": 0 # compbined qkv adapter + # "attn_q": 1 # q adapter + # "attn_k": 2 # k adapter + # "attn_v": 3 # v adapter + # "attn_dense": 4 # adapter for the dense layer in attention + # "mlp_h_to_4h": 5 # for llama2 adapter for gated mlp layer after attention / RMSNorm: up projection + # "mlp_4h_to_h": 6 # for llama2 adapter for gated mlp layer after attention / RMSNorm: down projection + # "mlp_gate": 7 # for llama2 adapter for gated mlp later after attention / RMSNorm: gate + # + # last dim holds [ module_id, layer_idx, adapter_size (D aka R value) ] + { + name: "lora_config" + data_type: TYPE_INT32 + dims: [ -1, 3 ] + optional: true + allow_ragged_batch: true + } +] +output [ + { + name: "output_ids" + data_type: TYPE_INT32 + dims: [ -1, -1 ] + }, + { + name: "sequence_length" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] +parameters: { + key: "max_beam_width" + value: { + string_value: "1" + } +} +parameters: { + key: "FORCE_CPU_ONLY_INPUT_TENSORS" + value: { + string_value: "no" + } +} +parameters: { + key: "gpt_model_type" + value: { + string_value: "inflight_fused_batching" + } +} +parameters: { + key: "gpt_model_path" + value: { + string_value: "./model_weights" + } +} +parameters: { + key: "max_tokens_in_paged_kv_cache" + value: { + string_value: "${max_tokens_in_paged_kv_cache}" + } +} +parameters: { + key: "max_attention_window_size" + value: { + string_value: "${max_attention_window_size}" + } +} +parameters: { + key: "batch_scheduler_policy" + value: { + string_value: "max_utilization" + } +} +parameters: { + key: "kv_cache_free_gpu_mem_fraction" + value: { + string_value: "0.9" + } +} +parameters: { + key: "enable_trt_overlap" + value: { + string_value: "${enable_trt_overlap}" + } +} +parameters: { + key: "exclude_input_in_output" + value: { + string_value: "true" + } +} +parameters: { + key: "enable_kv_cache_reuse" + value: { + string_value: "${enable_kv_cache_reuse}" + } +} +parameters: { + key: "normalize_log_probs" + value: { + string_value: "${normalize_log_probs}" + } +} +parameters: { + key: "enable_chunked_context" + value: { + string_value: "${enable_chunked_context}" + } +} +parameters: { + key: "gpu_device_ids" + value: { + string_value: "${gpu_device_ids}" + } +} diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/1/model.py b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/1/model.py new file mode 100644 index 00000000..545e3a7d --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/1/model.py @@ -0,0 +1,389 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import json +import traceback + +import numpy as np +import triton_python_backend_utils as pb_utils + + +class TritonPythonModel: + def initialize(self, args): + + # Parse model configs + model_config = json.loads(args["model_config"]) + + params = model_config["parameters"] + + accumulate_tokens_str = "" + if "accumulate_tokens" in params: + accumulate_tokens_str = params["accumulate_tokens"]["string_value"] + + self.accumulate_tokens = accumulate_tokens_str.lower() in ["true", "yes", "1", "t"] + + self.decoupled = pb_utils.using_decoupled_model_transaction_policy(model_config) + + self.logger = pb_utils.Logger + + self.bls_input_tensor_names = [ + "text_input", + "max_tokens", + "bad_words", + "stop_words", + "end_id", + "pad_id", + "top_k", + "top_p", + "temperature", + "length_penalty", + "repetition_penalty", + "min_length", + "presence_penalty", + "frequency_penalty", + "random_seed", + "return_log_probs", + "return_context_logits", + "return_generation_logits", + "beam_width", + "stream", + "prompt_embedding_table", + "prompt_vocab_size", + "embedding_bias_words", + "embedding_bias_weights", + ] + + self.preproc_input_to_bls_input_map = { + "QUERY": "text_input", + "REQUEST_OUTPUT_LEN": "max_tokens", + "BAD_WORDS_DICT": "bad_words", + "STOP_WORDS_DICT": "stop_words", + "EMBEDDING_BIAS_WORDS": "embedding_bias_words", + "EMBEDDING_BIAS_WEIGHTS": "embedding_bias_weights", + "END_ID": "end_id", + "PAD_ID": "pad_id", + } + + self.preproc_output_to_trtllm_input_map = { + "INPUT_ID": "input_ids", + "REQUEST_INPUT_LEN": "input_lengths", + "REQUEST_OUTPUT_LEN": "request_output_len", + "BAD_WORDS_IDS": "bad_words_list", + "STOP_WORDS_IDS": "stop_words_list", + "EMBEDDING_BIAS": "embedding_bias", + "OUT_END_ID": "end_id", + "OUT_PAD_ID": "pad_id", + } + + self.trtllm_input_to_bls_input_map = { + "beam_width": "beam_width", + "runtime_top_k": "top_k", + "runtime_top_p": "top_p", + "len_penalty": "length_penalty", + "repetition_penalty": "repetition_penalty", + "min_length": "min_length", + "presence_penalty": "presence_penalty", + "frequency_penalty": "frequency_penalty", + "random_seed": "random_seed", + "return_log_probs": "return_log_probs", + "return_context_logits": "return_context_logits", + "return_generation_logits": "return_generation_logits", + "streaming": "stream", + "prompt_embedding_table": "prompt_embedding_table", + "prompt_vocab_size": "prompt_vocab_size", + } + + self.trtllm_output_to_postproc_input_map = { + "output_ids": "TOKENS_BATCH", + "sequence_length": "SEQUENCE_LENGTH", + "cum_log_probs": "CUM_LOG_PROBS", + "output_log_probs": "OUTPUT_LOG_PROBS", + "context_logits": "CONTEXT_LOGITS", + "generation_logits": "GENERATION_LOGITS", + } + + self.postproc_output_to_bls_output_map = { + "OUTPUT": "text_output", + "OUT_CUM_LOG_PROBS": "cum_log_probs", + "OUT_OUTPUT_LOG_PROBS": "output_log_probs", + "OUT_CONTEXT_LOGITS": "context_logits", + "OUT_GENERATION_LOGITS": "generation_logits", + } + + def _get_bls_input_tensors_map(self, request): + + bls_input_tensors_map = {} + for input_tensor_name in self.bls_input_tensor_names: + tensor = pb_utils.get_input_tensor_by_name(request, input_tensor_name) + if tensor is not None: + bls_input_tensors_map[input_tensor_name] = tensor + + return bls_input_tensors_map + + def _get_preproc_input_tensors(self, bls_input_tensors_map): + + preproc_input_tensors = [] + + for preproc_name, bls_name in self.preproc_input_to_bls_input_map.items(): + + if bls_name in bls_input_tensors_map: + tensor = bls_input_tensors_map[bls_name] + # Change the name to what the preprocessor expects + preproc_input_tensors.append(pb_utils.Tensor(preproc_name, tensor.as_numpy())) + + return preproc_input_tensors + + def _get_trtllm_input_tensors(self, bls_input_tensors_map, preproc_output_tensors): + + trtllm_input_tensors = [] + + # Set input tensors from preprocessor outputs + for preproc_output_tensor in preproc_output_tensors: + + trtllm_tensor_name = self.preproc_output_to_trtllm_input_map[ + preproc_output_tensor.name() + ] + trtllm_input_tensors.append( + pb_utils.Tensor(trtllm_tensor_name, preproc_output_tensor.as_numpy()) + ) + + # Set input tensors from bls inputs + for trtllm_name, bls_name in self.trtllm_input_to_bls_input_map.items(): + + if bls_name in bls_input_tensors_map: + tensor = bls_input_tensors_map[bls_name] + # Change the name to what the preprocessor expects + trtllm_input_tensors.append(pb_utils.Tensor(trtllm_name, tensor.as_numpy())) + + return trtllm_input_tensors + + def _get_postproc_input_tensors(self, tokens, trtllm_output_tensors): + + postproc_input_tensors = [] + + for trtllm_output_tensor in trtllm_output_tensors: + + # If in decoupled mode, option to append new tokens to existing tokens before calling postprocessor + # This might be needed for some tokenizers + # Note that in that case, the client must overwrite previously received output text + if ( + self.accumulate_tokens + and self.decoupled + and trtllm_output_tensor.name() == "output_ids" + ): + + new_tokens = trtllm_output_tensor.as_numpy() + if new_tokens.ndim != 3: + raise pb_utils.TritonModelException( + "Expected output_ids tensor to have 3 dims." + ) + if new_tokens.shape[0] != 1: + raise pb_utils.TritonModelException( + "Expected output_ids tensor to have batch size of 1" + ) + if new_tokens.shape[1] != 1: + raise pb_utils.TritonModelException( + "Accumulation of tokens is only implemented for beam width = 1" + ) + + tokens = ( + new_tokens if (tokens is None) else np.concatenate((tokens, new_tokens), axis=2) + ) + + # output ids + postproc_output_ids_name = self.trtllm_output_to_postproc_input_map["output_ids"] + postproc_input_tensors.append(pb_utils.Tensor(postproc_output_ids_name, tokens)) + + # sequence length + np_seq_len_tensor = np.array([[tokens.shape[2]]], dtype=np.int32) + postproc_seq_len_name = self.trtllm_output_to_postproc_input_map["sequence_length"] + postproc_input_tensors.append( + pb_utils.Tensor(postproc_seq_len_name, np_seq_len_tensor) + ) + + # Set input tensors from trtllm outputs + for trtllm_output_tensor in trtllm_output_tensors: + + # output_ids and sequence_length were handled earlier + if ( + self.accumulate_tokens + and self.decoupled + and ( + trtllm_output_tensor.name() == "output_ids" + or trtllm_output_tensor.name() == "sequence_length" + ) + ): + continue + + postproc_tensor_name = self.trtllm_output_to_postproc_input_map[ + trtllm_output_tensor.name() + ] + + postproc_input_tensors.append( + pb_utils.Tensor(postproc_tensor_name, trtllm_output_tensor.as_numpy()) + ) + + return tokens, postproc_input_tensors + + def _get_bls_output_tensors(self, postproc_output_tensors): + + bls_output_tensors = [] + + # Set input tensors from trtllm outputs + for postproc_output_tensor in postproc_output_tensors: + + bls_tensor_name = self.postproc_output_to_bls_output_map[postproc_output_tensor.name()] + bls_output_tensors.append( + pb_utils.Tensor(bls_tensor_name, postproc_output_tensor.as_numpy()) + ) + + return bls_output_tensors + + def execute(self, requests): + + responses = [] + bls_response_sender = None + + for request in requests: + + # Get the response sender for the BLS + if self.decoupled: + bls_response_sender = request.get_response_sender() + + try: + # Get the bls input tensors + bls_input_tensors_map = self._get_bls_input_tensors_map(request) + + # Check the batch dimension + for name, tensor in bls_input_tensors_map.items(): + batch_dim = tensor.as_numpy().shape[0] + + if batch_dim != 1: + + err_str = "Inflight batching backend expects requests with batch size of 1." + self.logger.log_error(err_str) + raise pb_utils.TritonModelException(err_str) + + # Create the preprocessor input tensors + preproc_input_tensors = self._get_preproc_input_tensors(bls_input_tensors_map) + + preproc_request = pb_utils.InferenceRequest( + model_name="preprocessing", + inputs=preproc_input_tensors, + requested_output_names=list(self.preproc_output_to_trtllm_input_map.keys()), + ) + + # Execute preprocessor + preproc_response = preproc_request.exec() + + if preproc_response.has_error(): + raise pb_utils.TritonModelException(preproc_response.error().message()) + + # Create the trtllm input tensors + trtllm_input_tensors = self._get_trtllm_input_tensors( + bls_input_tensors_map, preproc_response.output_tensors() + ) + + trtllm_request = pb_utils.InferenceRequest( + model_name="tensorrt_llm", + inputs=trtllm_input_tensors, + requested_output_names=list(self.trtllm_output_to_postproc_input_map.keys()), + ) + + # Execute trtllm + trtllm_responses = trtllm_request.exec(decoupled=self.decoupled) + + if not self.decoupled: + trtllm_responses = [trtllm_responses] + + tokens = None + + # Loop over the trtllm responses + for trtllm_response in trtllm_responses: + + if trtllm_response.has_error(): + raise pb_utils.TritonModelException(trtllm_response.error().message()) + + trtllm_output_tensors = trtllm_response.output_tensors() + + tokens, postproc_input_tensors = self._get_postproc_input_tensors( + tokens, trtllm_output_tensors + ) + + postproc_request = pb_utils.InferenceRequest( + model_name="postprocessing", + inputs=postproc_input_tensors, + requested_output_names=list(self.postproc_output_to_bls_output_map.keys()), + ) + + # Execute postprocessor + postproc_response = postproc_request.exec() + + if postproc_response.has_error(): + raise pb_utils.TritonModelException(postproc_response.error().message()) + + # Create the BLS response + bls_output_tensors = self._get_bls_output_tensors( + postproc_response.output_tensors() + ) + + bls_response = pb_utils.InferenceResponse(output_tensors=bls_output_tensors) + + if self.decoupled: + bls_response_sender.send(bls_response) + else: + responses.append(bls_response) + + # All responses have been sent, set final flag + if self.decoupled: + bls_response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + + except Exception: + + self.logger.log_error(traceback.format_exc()) + # If encountering an error, send a response with err msg + error_response = pb_utils.InferenceResponse( + output_tensors=[], error=pb_utils.TritonError(traceback.format_exc()) + ) + + if self.decoupled: + bls_response_sender.send(error_response) + bls_response_sender.send(flags=pb_utils.TRITONSERVER_RESPONSE_COMPLETE_FINAL) + else: + responses.append(error_response) + + if self.decoupled: + return None + else: + assert len(responses) == len(requests) + return responses + + def finalize(self): + """`finalize` is called only once when the model is being unloaded. + Implementing `finalize` function is optional. This function allows + the model to perform any necessary clean ups before exit. + """ + print("Cleaning up...") diff --git a/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/config.pbtxt b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/config.pbtxt new file mode 100644 index 00000000..168c819c --- /dev/null +++ b/model-engine/model_engine_server/inference/tensorrt-llm/triton_model_repo/tensorrt_llm_bls/config.pbtxt @@ -0,0 +1,221 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +name: "tensorrt_llm_bls" +backend: "python" +max_batch_size: 128 + +model_transaction_policy { + decoupled: true +} + +input [ + { + name: "text_input" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "max_tokens" + data_type: TYPE_INT32 + dims: [ -1 ] + }, + { + name: "bad_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "stop_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "end_id" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "pad_id" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_k" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "top_p" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "temperature" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "length_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "repetition_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "min_length" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "presence_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "frequency_penalty" + data_type: TYPE_FP32 + dims: [ 1 ] + optional: true + }, + { + name: "random_seed" + data_type: TYPE_UINT64 + dims: [ 1 ] + optional: true + }, + { + name: "return_log_probs" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "return_context_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "return_generation_logits" + data_type: TYPE_BOOL + dims: [ 1 ] + reshape: { shape: [ ] } + optional: true + }, + { + name: "beam_width" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "stream" + data_type: TYPE_BOOL + dims: [ 1 ] + optional: true + }, + { + name: "prompt_embedding_table" + data_type: TYPE_FP16 + dims: [ -1, -1 ] + optional: true + }, + { + name: "prompt_vocab_size" + data_type: TYPE_INT32 + dims: [ 1 ] + optional: true + }, + { + name: "embedding_bias_words" + data_type: TYPE_STRING + dims: [ -1 ] + optional: true + }, + { + name: "embedding_bias_weights" + data_type: TYPE_FP32 + dims: [ -1 ] + optional: true + } +] +output [ + { + name: "text_output" + data_type: TYPE_STRING + dims: [ -1 ] + }, + { + name: "cum_log_probs" + data_type: TYPE_FP32 + dims: [ -1 ] + }, + { + name: "output_log_probs" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "context_logits" + data_type: TYPE_FP32 + dims: [ -1, -1 ] + }, + { + name: "generation_logits" + data_type: TYPE_FP32 + dims: [ -1, -1, -1 ] + } +] + +parameters: { + key: "accumulate_tokens" + value: { + string_value: "true" + } +} + +instance_group [ + { + count: 1 + kind : KIND_CPU + } +] diff --git a/server/tests/unit/common/__init__.py b/model-engine/model_engine_server/inference/tool_completion/__init__.py similarity index 100% rename from server/tests/unit/common/__init__.py rename to model-engine/model_engine_server/inference/tool_completion/__init__.py diff --git a/model-engine/model_engine_server/inference/tool_completion/base.py b/model-engine/model_engine_server/inference/tool_completion/base.py new file mode 100644 index 00000000..6166987a --- /dev/null +++ b/model-engine/model_engine_server/inference/tool_completion/base.py @@ -0,0 +1,17 @@ +from typing import Optional, Tuple + + +class BaseTool: + """ + Base class for third-party tools. + """ + + tool_context_start = "" + tool_call_token = "" + tool_context_end = "" + + def __call__(self, expression: str, past_context: Optional[str]) -> Tuple[str, int]: + """ + Call method to be overridden by child classes. + """ + raise NotImplementedError("The evaluate method must be implemented by child classes.") diff --git a/model-engine/model_engine_server/inference/tool_completion/tools.py b/model-engine/model_engine_server/inference/tool_completion/tools.py new file mode 100644 index 00000000..8e84bdff --- /dev/null +++ b/model-engine/model_engine_server/inference/tool_completion/tools.py @@ -0,0 +1,249 @@ +import re +import subprocess +from enum import Enum +from typing import Optional, Tuple + +import docker +from model_engine_server.inference.tool_completion.base import BaseTool +from model_engine_server.inference.tool_completion.utils import ( + FIX_ERRORS_MAPPING, + NAME_ERROR_PATTERN, + PRINT_PATTERN, +) +from transformers import LlamaTokenizer + +tokenizer = LlamaTokenizer.from_pretrained("openlm-research/open_llama_7b", legacy=False) +MAX_CODEBLOCK_RETRIES = 3 + + +class CodeBlockEvaluator(BaseTool): # pragma: no cover + """ + A evaluator to "pseudo-safely" execute python code blocks. + Executes code from a model generated response using a safe python interpreter. + the code should have the following format: + + ```python + {code} + ``` + {output} + >>> + + The output will be replaced with the output from executing the code. + """ + + tool_context_start = "```python\n" + tool_call_token = "\n```\n" + tool_context_end = "\n>>>\n" + + @staticmethod + def _cleanup_code_error(error_code: str) -> str: + """This function will clean up an error code from code execution + + Args: + error_code (str): The full error code (e.g. like below): + + Command '['python', '-c', 'import math\nx = 2\nmath.sqrt(y)']' in image 'continuumio/anaconda3' + returned non-zero exit status 1: b'Traceback (most recent call last): + File "", line 3, in \nNameError: name \'y\' is not defined\n' + + Returns: + str: like the following: + + Traceback (most recent call last): File "", line 3, in + NameError: name \'y\' is not defined + + """ + if "Traceback" not in error_code: + return error_code + + # Let's find the byte string: (e.g. b') + stacktrace = error_code.split("b'")[-1] + + # Now read it as a bytestring + stacktrace = "\n" + stacktrace.encode("utf-8").decode("unicode_escape") + + return stacktrace.strip("'") + + def __init__(self): + # Condition to check if we can use docker + try: + self.client = docker.from_env() + self.evaluate = self.evaluate_code_in_docker + except docker.errors.DockerException: + # If docker is not available, use the python interpreter + self.evaluate = self.evaluate_code_using_exec + + def __call__( + self, + expression: str, + past_context: Optional[str] = None, + ) -> Tuple[str, int]: + """ + Given an expression, extract the code block and execute it using a safe python interpreter. Additionally, + approximate the number of tokens added to the expression from the tool output along with handling retries + due to simple tool errors (e.g. import errors, missing variables) + + Args: + expression (str): text with natural language and code blocks + past_context (Optional[str]): previously generated code blocks for retrying simple code errors + + Returns: + str: Formatted output from the code execution tool + int: Number of tokens added + + Raises: + RuntimeError: If any errors occur during the code execution or retries for simple code errors. + """ + tool_output = "" + expression_ = expression + num_tokens = 0 + if (CodeBlockEvaluator.tool_context_start in expression) and ( + CodeBlockEvaluator.tool_call_token in expression + ): + # Extract the expression between the start token and the special token for the tool to evaluate + code_expression = expression.split(CodeBlockEvaluator.tool_context_start)[-1].split( + CodeBlockEvaluator.tool_call_token + )[0] + + # Note: Can increase max retries if needed (e.g. > 1 import errors + variable not defined in code_expression) + for retry_count in range(MAX_CODEBLOCK_RETRIES): + try: + tool_output = self.evaluate(code_expression) + break + except Exception as e: + name_error = re.search(NAME_ERROR_PATTERN, str(e)) + if ( + past_context is None + or name_error is None + or retry_count == MAX_CODEBLOCK_RETRIES - 1 + ): + error_code = self._cleanup_code_error(str(e)) + raise RuntimeError(f"failed with error: {error_code}") + + if retry_count == 0 and past_context != "": + # Grab all the prior code blocks in "```python\n{code}\n```\n" format + code_expression = ( + self._extract_code_blocks(past_context) + "\n" + code_expression + ) + else: + current_error = name_error.group(1).replace("\\", "") + # Make sure error is one of the fixable/common import errors seen in the past + if current_error not in FIX_ERRORS_MAPPING.keys(): + error_code = self._cleanup_code_error(str(e)) + raise RuntimeError( + f"failed on retry: {retry_count}, NameError variable: {current_error}, and error: {error_code}" + ) + + code_expression = FIX_ERRORS_MAPPING[current_error] + "\n" + code_expression + + tool_output = ( + CodeBlockEvaluator.tool_call_token + + tool_output + + CodeBlockEvaluator.tool_context_end + ) + + expression_ = expression.split(CodeBlockEvaluator.tool_call_token)[0] + tool_output + num_tokens = max( + 0, len(tokenizer(expression_).input_ids) - len(tokenizer(expression).input_ids) + ) + return expression_, num_tokens + + def _extract_code_blocks(self, context: str): + """ + Given some text (e.g. previous completion), extract all the code blocks in the format + along with removing any old print statements. + + Args: + context (str): text with natural language and code blocks + + Returns: + str: Parsed code blocks with print statements removed + """ + code_block_pattern = re.compile( + rf"{CodeBlockEvaluator.tool_context_start}(.*?){CodeBlockEvaluator.tool_call_token}", + re.DOTALL, + ) + code_block_matches = code_block_pattern.findall(context) + # Remove lines with print statements bc already included in model response + cleaned_code_blocks = [] + for code_block in code_block_matches: + no_print_code_blocks = [] + for line in code_block.split("\n"): + # Ignore lines with print format + if re.search(PRINT_PATTERN, line) is None: + no_print_code_blocks.append(line) + cleaned_code_blocks.append("\n".join(no_print_code_blocks)) + return "\n".join(cleaned_code_blocks) + + def evaluate_code_in_docker(self, code: str) -> str: + """ + Executes a block of code using a safe python interpreter and returns the output as a string. + + This function uses a docker container to safely execute a given block of code. + The function returns the output of the last executed line, if any. + + Args: + code (str): A string containing the Python code to be executed. + + Returns: + str: The output of the executed code, converted to string. If there's no explicit output, + the function returns the result of the last line of code. + + Raises: + RuntimeError: If any errors occur during the code execution. + """ + + try: + output = self.client.containers.run( + "continuumio/anaconda3", command=["python", "-c", code] + ).decode() + output = output.strip() + except docker.errors.ContainerError as e: + raise RuntimeError(e) + + return output + + def evaluate_code_using_exec(self, code: str) -> str: + """ + Executes a block of code using the python "exec" function. Returns the output as a string. + Unfortunately it doesn't have the same safety guarantees as the docker version. + + However, it will only ever be enabled when we are in a scale environment as we check the llmengine + path. + + Args: + code (str): A string containing the Python code to be executed. + + Returns: + str: The output of the executed code, converted to string. If there's no explicit output, + the function returns the result of the last line of code. + """ + try: + p = subprocess.run(["python", "-c", code], capture_output=True, text=True) + p.check_returncode() # Raises CalledProcessError if the exit code is non-zero + output_str = p.stdout + + # If output is empty and the last line didn't have a print statement, edit the code to add one + if output_str == "" and "print" not in code.split("\n")[-1]: + new_code = "\n".join(code.split("\n")[:-1]) + last_line = code.split("\n")[-1] + new_code = new_code + f"\nprint({last_line})" + + # Re-run it + p = subprocess.run(["python", "-c", new_code], capture_output=True, text=True) + p.check_returncode() + output_str = p.stdout + + except subprocess.CalledProcessError as e: + raise RuntimeError(p.stderr) from e + + return output_str + + +class Tools(str, Enum): + CODE_EVALUATOR = "code_evaluator" + + +TOOL_MAP = { + Tools.CODE_EVALUATOR: CodeBlockEvaluator, +} diff --git a/model-engine/model_engine_server/inference/tool_completion/utils.py b/model-engine/model_engine_server/inference/tool_completion/utils.py new file mode 100644 index 00000000..bb30d116 --- /dev/null +++ b/model-engine/model_engine_server/inference/tool_completion/utils.py @@ -0,0 +1,107 @@ +from queue import Queue +from typing import Tuple + +from model_engine_server.inference.tool_completion.base import BaseTool + +NAME_ERROR_PATTERN = r"NameError: name \\?'([^']+)\\?' is not defined" + +PRINT_PATTERN = r"print\(.+?\)" + +# Most common imports used during code execution +FIX_ERRORS_MAPPING = { + "math": "import math", + "np": "import numpy as np", + "cmath": "import cmath", + "norm": "from scipy.stats import norm", + "plt": "import matplotlib.pyplot as plt", + "sp": "import sympy as sp", + "sympy": "import sympy", + "sqrt": "from cmath import sqrt", + "erfinv": "from scipy.special import erfinv", + "t": "from scipy.stats import t", + "comb": "from scipy.special import comb", + "Fraction": "from fractions import Fraction", + "st": "import steam_table as st", + "pd": "import pandas as pd", + "stats": "import scipy.stats as stats", + "opt": "import scipy.optimize as opt", + "Counter": "from collections import Counter", + "datetime": "import datetime", + "gcd": "from fractions import gcd", + "pi": "from math import pi", + "quad": "from scipy.integrate import quad", + "fsolve": "from scipy.optimize import fsolve", + "factorial": "from math import factorial", + "tan": "from math import tan", + "log": "from math import log", + "symbols": "from sympy import symbols, sin, cos", + "integrate": "from sympy import symbols, integrate", + "diff": "from sympy import symbols, sin, cos, diff", + "sin": "from sympy import symbols, sin, cos", + "cos": "from sympy import symbols, sin, cos", + "time": "import time", + "Symbol": "from sympy import Symbol", +} + + +# Check if a model response indicates it could be starting a tool +def check_streaming_tool_start(stream_queue: Queue, tool: BaseTool) -> bool: # pragma: no cover + # If the queue is empty, we can't start the tool + if stream_queue.qsize() == 0: + return False + + # Create the full string from the queue + queue_str = "" + for response in list(stream_queue.queue): + queue_str += response.output.text + + # Check if the start token is in the queue + if tool.tool_context_start in queue_str: + return True + + return False + + +def check_either_substr(str1: str, str2: str) -> bool: + return str1 in str2 or str2 in str1 + + +# Check if some responses from the queue should be returned +def get_responses_to_yield( + stream_queue: Queue, tool: BaseTool, tool_started: bool +) -> Tuple[Queue, Queue]: # pragma: no cover + """We return a tuple, (responses_to_yield, stream_queue) based on what should be returned""" + # If we've started the tool, we shouldn't yield anything + if tool_started: + return Queue(), stream_queue + + # Otherwise, we should yield everything in the queue that *can't* be part of the start of a tool + concatenated_queue_str = "" + responses_to_yield: Queue = Queue() # These are values we're sure we want to return right now + undecided_queue: Queue = ( + Queue() + ) # These are values that could be part of start token but we aren't sure yet + + # Iterate through the queue and add to the concatenated queue string + while stream_queue.qsize() > 0: + response = stream_queue.get() + + # First check if the adding the current response could be part of the start token + if check_either_substr( + concatenated_queue_str + response.output.text, tool.tool_context_start + ): + # If so, add it to the undecided queue + undecided_queue.put(response) + concatenated_queue_str += response.output.text + + # Otherwise, we are confident that everything in the undecided *can't* be part of the start token + # in addition to the concatenated queue string + else: + while not undecided_queue.empty(): + responses_to_yield.put(undecided_queue.get()) + + responses_to_yield.put(response) + concatenated_queue_str = "" + + # Finally, return the responses to yield and the new stream queue + return responses_to_yield, undecided_queue diff --git a/model-engine/model_engine_server/inference/user.Dockerfile b/model-engine/model_engine_server/inference/user.Dockerfile new file mode 100644 index 00000000..6ed69146 --- /dev/null +++ b/model-engine/model_engine_server/inference/user.Dockerfile @@ -0,0 +1,8 @@ +ARG BASE_IMAGE +FROM ${BASE_IMAGE} + +ARG REQUIREMENTS_FILE +COPY --chown=root ${REQUIREMENTS_FILE} /app/model-engine/model_engine_server/inference/requirements.txt +RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r /app/model-engine/model_engine_server/inference/requirements.txt + +ENV PYTHONPATH /app diff --git a/model-engine/model_engine_server/inference/utils.py b/model-engine/model_engine_server/inference/utils.py new file mode 100644 index 00000000..e6137415 --- /dev/null +++ b/model-engine/model_engine_server/inference/utils.py @@ -0,0 +1,117 @@ +import asyncio +import subprocess +import sys +import uuid +from typing import Any, AsyncIterator, Coroutine, Tuple, Union + +from typing_extensions import TypeVar + + +def get_cpu_cores_in_container() -> int: + import multiprocessing + + cpu_count = multiprocessing.cpu_count() + try: + with open("/sys/fs/cgroup/cpu/cpu.cfs_quota_us") as fp: + cfs_quota_us = int(fp.read()) + with open("/sys/fs/cgroup/cpu/cpu.cfs_period_us") as fp: + cfs_period_us = int(fp.read()) + if cfs_quota_us != -1: + cpu_count = cfs_quota_us // cfs_period_us + except FileNotFoundError: + pass + return cpu_count + + +def get_gpu_free_memory(): # pragma: no cover + """Get GPU free memory using nvidia-smi.""" + try: + output = subprocess.run( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"], + capture_output=True, + text=True, + ).stdout + gpu_memory = [int(x) for x in output.strip().split("\n")] + return gpu_memory + except Exception as e: + print(f"Error getting GPU memory: {e}") + return None + + +def check_unknown_startup_memory_usage(): # pragma: no cover + """Check for unknown memory usage at startup.""" + gpu_free_memory = get_gpu_free_memory() + if gpu_free_memory is not None: + print(f"GPU free memory at startup in MB: {gpu_free_memory}") + min_mem = min(gpu_free_memory) + max_mem = max(gpu_free_memory) + if max_mem - min_mem > 10: + print( + f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." + ) + try: + output = subprocess.run( + ["fuser -v /dev/nvidia*"], + shell=True, # nosemgrep + capture_output=True, + text=True, + ).stdout + print(f"Processes using GPU: {output}") + except Exception as e: + print(f"Error getting processes using GPU: {e}") + + +def random_uuid() -> str: + return str(uuid.uuid4()) + + +T = TypeVar("T") + + +class ProducerFinished: + pass + + +def await_coroutines(*coroutines: Coroutine[Any, Any, T]) -> AsyncIterator[Tuple[int, T]]: + """Await multiple coroutines concurrently. + + Returns an async iterator that yields the results of the coroutines as they complete. + """ + queue: asyncio.Queue[Union[Tuple[int, T], ProducerFinished, Exception]] = asyncio.Queue() + + async def producer(i: int, coroutine: Coroutine[Any, Any, T]): + try: + result = await coroutine + await queue.put((i, result)) + except Exception as e: + await queue.put(e) + # Signal to the consumer that we've finished + await queue.put(ProducerFinished()) + + _tasks = [asyncio.create_task(producer(i, coroutine)) for i, coroutine in enumerate(coroutines)] + + async def consumer(): + remaining = len(coroutines) + try: + while remaining or not queue.empty(): + item = await queue.get() + + if isinstance(item, ProducerFinished): + # Signal that a producer finished- not a real item + remaining -= 1 + continue + + if isinstance(item, Exception): + raise item + yield item + except (Exception, asyncio.CancelledError) as e: + for task in _tasks: + if sys.version_info >= (3, 9): + # msg parameter only supported in Python 3.9+ + task.cancel(e) + else: + task.cancel() + raise e + await asyncio.gather(*_tasks) + + return consumer() diff --git a/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm new file mode 100644 index 00000000..76085416 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/Dockerfile.vllm @@ -0,0 +1,51 @@ +# syntax=docker/dockerfile:1 +ARG VLLM_VERSION=0.6.3 +ARG VLLM_BASE_REPO=vllm/vllm-openai +ARG VLLM_BASE_IMAGE=${VLLM_BASE_REPO}:v${VLLM_VERSION} +FROM ${VLLM_BASE_IMAGE} AS base + +RUN apt-get update \ + && apt-get install -y wget gdb psmisc dumb-init \ + && apt-get autoremove -y \ + && rm -rf /var/lib/apt/lists/* \ + apt-get clean + +WORKDIR /workspace + +RUN wget https://github.com/peak/s5cmd/releases/download/v2.2.1/s5cmd_2.2.1_Linux-64bit.tar.gz +RUN tar -xvzf s5cmd_2.2.1_Linux-64bit.tar.gz + +# symlink python to python3 +RUN ln -s /usr/bin/python3 /usr/bin/python + +FROM base AS vllm + +COPY model-engine/model_engine_server/inference/vllm/vllm_server.py /workspace/vllm_server.py +COPY model-engine/model_engine_server/inference/vllm/init_ray.sh /workspace/init_ray.sh + +# Need to override entrypoint from parent image +ENTRYPOINT ["/bin/env"] + +FROM base AS vllm_batch + +COPY model-engine/model_engine_server/inference/batch_inference/requirements.txt /workspace/requirements.txt +RUN pip install -r requirements.txt + +COPY model-engine /workspace/model-engine +RUN pip install -e /workspace/model-engine +COPY model-engine/model_engine_server/inference/batch_inference/vllm_batch.py /workspace/vllm_batch.py + +# Need to override entrypoint from parent image +ENTRYPOINT ["/bin/env"] + +FROM base AS vllm_batch_v2 + +COPY model-engine/model_engine_server/inference/vllm/requirements-batch.txt /workspace/requirements.txt +RUN pip install -r requirements.txt + +COPY model-engine /workspace/model-engine +RUN pip install -e /workspace/model-engine +COPY model-engine/model_engine_server/inference/vllm/vllm_batch.py /workspace/vllm_batch.py + +# Need to override entrypoint from parent image +ENTRYPOINT ["/bin/env"] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/README.md b/model-engine/model_engine_server/inference/vllm/README.md new file mode 100644 index 00000000..486b528c --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/README.md @@ -0,0 +1,62 @@ +# VLLM + +## Building container + +There are three build targets for vLLM. +1. vLLM endpoint +2. vLLM batch job v1 +3. vLLM batch job v2 + +```bash +VLLM_VERSION=0.5.4 bash build_and_upload_image.sh $ACCOUNT_ID $IMAGE_TAG {BUILD_TARGET=vllm|vllm_batch|vllm_batch_v2} +``` + +## Running locally + +### Endpoint + +1. Download model weights to `model_files` +2. Run docker locally +```bash +IMAGE=${ACCOUNT_ID}.dkr.ecr.us-west-2.amazonaws.com/vllm:${IMAGE_TAG} +docker kill vllm; docker rm vllm; +docker run \ + --runtime nvidia \ + --shm-size=16gb \ + --gpus '"device=0"' \ + -v $MODEL_PATH:/workspace/model_files:ro \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_server.py:/workspace/vllm_server.py \ + -p 5005:5005 \ + --name vllm \ + ${IMAGE} \ + python -m vllm_server --model model_files --tensor-parallel-size 1 --port 5005 --disable-log-requests +``` + +3. Send curl requests +```bash +curl -X POST localhost:5005/v1/chat/completions \ + -H "Content-Type: application/json" \ + -d '{"messages":[{"role": "user", "content": "Hey, whats the temperature in Paris right now?"}],"model":"model_files","max_tokens":100,"temperature":0.2,"guided_regex":"Sean.*"}' +``` + +### Batch job v2 +```bash +IMAGE_BATCH=${ACCOUNT_ID}.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:${IMAGE_TAG} + +export MODEL=gemma-2-2b-it && export MODEL_PATH=/data/model_files/$MODEL +docker kill vllm_batch; docker rm vllm_batch; +docker run \ + --runtime nvidia \ + --shm-size=16gb \ + --gpus '"device=6,7"' \ + -v $MODEL_PATH:/workspace/model_files:ro \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/examples:/workspace/examples \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_batch.py:/workspace/vllm_batch.py \ + -p 5005:5005 \ + -e CONFIG_FILE=/workspace/examples/v2/gemma/config.json \ + -e MODEL_WEIGHTS_FOLDER=/workspace/model_files \ + --name vllm_batch \ + ${IMAGE_BATCH} \ + python vllm_batch.py + +``` \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh new file mode 100755 index 00000000..3b1ab4cb --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/build_and_upload_image.sh @@ -0,0 +1,49 @@ +#!/bin/bash + +set -eo pipefail + +# Build and push vLLM docker image to AWS ECR. +# +# Usage: VLLM_VERSION=0.6.3 ./build_and_upload_image.sh vllm|vllm_batch|vllm_batch_v2 + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +PROJECT_DIR=$SCRIPT_DIR/../../../.. +DOCKERFILE=$PROJECT_DIR/model_engine_server/inference/vllm/Dockerfile.vllm + +if [ -z "$1" ]; then + echo "Must supply AWS account ID" + exit 1; +fi + +if [ -z "$2" ]; then + echo "Must supply the image tag" + exit 1; +fi + +if [ -z "$3" ]; then + echo "Must supply the build target (either vllm or vllm_batch_v2)" + exit 1; +fi + + +ACCOUNT=$1 +IMAGE_TAG=$2 +BUILD_TARGET=$3 +VLLM_VERSION=${VLLM_VERSION:-"0.6.2"} +VLLM_BASE_REPO=${VLLM_BASE_REPO:-"vllm/vllm-openai"} + +# if build target = vllm use vllm otherwise use vllm_batch +if [ "$BUILD_TARGET" == "vllm" ]; then + IMAGE=$ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/vllm:$IMAGE_TAG +else + IMAGE=$ACCOUNT.dkr.ecr.us-west-2.amazonaws.com/llm-engine/batch-infer-vllm:$IMAGE_TAG +fi + +aws ecr get-login-password --region us-west-2 | docker login --username AWS --password-stdin $ACCOUNT.dkr.ecr.us-west-2.amazonaws.com +DOCKER_BUILDKIT=1 docker build \ + --build-arg VLLM_VERSION=${VLLM_VERSION} \ + --build-arg VLLM_BASE_REPO=${VLLM_BASE_REPO} \ + -f ${DOCKERFILE} \ + --target ${BUILD_TARGET} \ + -t $IMAGE ${PROJECT_DIR} +docker push $IMAGE diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/README.md b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/README.md new file mode 100644 index 00000000..08c3b213 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/README.md @@ -0,0 +1,19 @@ +# quick commands + +``` +export MODEL=gemma-2-2b-it && export MODEL_PATH=/data/model_files/$MODEL +docker kill vllm_batch; docker rm vllm_batch; +docker run \ + --runtime nvidia \ + --shm-size=16gb \ + --gpus '"device=6,7"' \ + -v $MODEL_PATH:/workspace/model_files:ro \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/examples:/workspace/examples \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_batch.py:/workspace/vllm_batch.py \ + -p 5005:5005 \ + -e CONFIG_FILE=/workspace/examples/v2/gemma/config.json \ + -e MODEL_WEIGHTS_FOLDER=/workspace/model_files \ + --name vllm_batch \ + ${IMAGE_BATCH} \ + python vllm_batch.py +``` \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config.json b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config.json new file mode 100644 index 00000000..fc98e6d0 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config.json @@ -0,0 +1,15 @@ +{ + "input_data_path": "./examples/v2/gemma/data_oai_chat.json", + "output_data_path": "./examples/v2/gemma/output_oi_chat.json", + "model_config": { + "model": "gemma-2-2b-it", + "checkpoint_path": "my_path", + "num_shards": 1, + "response_role": "assistant", + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config_w_oai_chat_content.json b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config_w_oai_chat_content.json new file mode 100644 index 00000000..15e35c4d --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/config_w_oai_chat_content.json @@ -0,0 +1,33 @@ +{ + "content": [ + { + "messages": [ + { + "role": "user", + "content": "What is a good place for travel in the US?" + }, + { + "role": "assistant", + "content": "California." + }, + { + "role": "user", + "content": "What can I do in California?" + } + ], + "logprobs": true + } + ], + "output_data_path": "./examples/v2/sample_output.json", + "model_config": { + "model": "gemma-2-2b-it", + "checkpoint_path": "my_path", + "num_shards": 1, + "response_role": "assistant", + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_chat.json b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_chat.json new file mode 100644 index 00000000..fbbf1286 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_chat.json @@ -0,0 +1,7 @@ +[ + {"messages": [ + {"role": "user", "content": "What is a good place for travel in the US?"}, + {"role": "assistant", "content": "California."}, + {"role": "user", "content": "What can I do in California?"}], + "logprobs": true} +] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_completion.json b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_completion.json new file mode 100644 index 00000000..2b1500b4 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/gemma/data_oai_completion.json @@ -0,0 +1,16 @@ +[ + { + "prompt": "What is a good place for travel in the US?", + "logprobs": true, + "echo": true, + "max_tokens": 7, + "temperature": 1 + }, + { + "prompt": "What is a good place for travel in the EU?", + "logprobs": true, + "echo": true, + "max_tokens": 7, + "temperature": 1 + } +] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/README.md b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/README.md new file mode 100644 index 00000000..cc7f3052 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/README.md @@ -0,0 +1,19 @@ +# quick commands + +``` +export MODEL=meta-llama/Llama-3.2-11B-Vision-Instruct && export MODEL_PATH=/data/model_files/$MODEL +docker kill vllm_batch; docker rm vllm_batch; +docker run \ + --runtime nvidia \ + --shm-size=16gb \ + --gpus '"device=6,7"' \ + -v $MODEL_PATH:/workspace/model_files:ro \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/examples:/workspace/examples \ + -v ${REPO_PATH}/llm-engine/model-engine/model_engine_server/inference/vllm/vllm_batch.py:/workspace/vllm_batch.py \ + -p 5005:5005 \ + -e CONFIG_FILE=/workspace/examples/v2/llama-3.2-vision/config.json \ + -e MODEL_WEIGHTS_FOLDER=/workspace/model_files \ + --name vllm_batch \ + ${IMAGE_BATCH} \ + python vllm_batch.py +``` \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/config.json b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/config.json new file mode 100644 index 00000000..a26a3b3c --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/config.json @@ -0,0 +1,18 @@ +{ + "input_data_path": "./examples/v2/llama-3.2-vision/data_oai_chat.json", + "output_data_path": "./examples/v2/llama-3.2-vision/output_oi_chat.json", + "model_config": { + "model": "meta-llama/Llama-3.2-11B-Vision-Instruct", + "checkpoint_path": "my_path", + "num_shards": 1, + "max_model_len": 4096, + "max_num_seqs": 16, + "enforce_eager": true, + "response_role": "assistant", + "labels": { + "team": "my_team" + } + }, + "attention_backend": "FLASHINFER", + "data_parallelism": 1 +} \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/data_oai_chat.json b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/data_oai_chat.json new file mode 100644 index 00000000..2cdc9656 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/data_oai_chat.json @@ -0,0 +1,22 @@ +[ + { + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What's in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + } + } + ] + } + ], + "max_tokens": 64 + } +] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/output_oi_chat.json b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/output_oi_chat.json new file mode 100644 index 00000000..402429c6 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/examples/v2/llama-3.2-vision/output_oi_chat.json @@ -0,0 +1 @@ +[{"id": "chat-b61abe3898714576802d92f36ab90c38", "object": "chat.completion", "created": 1727669398, "model": "/workspace/model_files", "choices": [{"index": 0, "message": {"role": "assistant", "content": "This image depicts a serene landscape with a long wooden boardwalk or path that stretches out into a field dotted with long green grass in the foreground and tall green and yellow grass and green and red shrubbery on the side of the path. In the background, there are large, short and thick green and yellow shrubs", "tool_calls": []}, "logprobs": null, "finish_reason": "length", "stop_reason": null}], "usage": {"prompt_tokens": 17, "total_tokens": 81, "completion_tokens": 64}, "prompt_logprobs": null}] \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/gen_sample_data.py b/model-engine/model_engine_server/inference/vllm/gen_sample_data.py new file mode 100644 index 00000000..2b2e9367 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/gen_sample_data.py @@ -0,0 +1,36 @@ +import json + +EXAMPLES_DIR = "examples/v2" + +messages = [ + { + "role": "user", + "content": "What is a good place for travel in the US?", + }, + { + "role": "assistant", + "content": "California.", + }, + { + "role": "user", + "content": "What can I do in California?", + }, +] + +if __name__ == "__main__": + + completion_type = "chat" + model = "gemma" + target_file = f"{EXAMPLES_DIR}/sample_data_{completion_type}_{model}.json" + + # request = CompletionCreateParamsNonStreaming( + # messages=messages, + # logprobs=True, + # max_tokens=300, + # ) + request = { + "messages": messages, + "logprobs": True, + "max_tokens": 300, + } + json.dump([request], open(target_file, "w")) diff --git a/model-engine/model_engine_server/inference/vllm/init_ray.sh b/model-engine/model_engine_server/inference/vllm/init_ray.sh new file mode 100755 index 00000000..f685206a --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/init_ray.sh @@ -0,0 +1,101 @@ +#!/bin/bash + +# From https://github.com/kubernetes-sigs/lws/blob/main/docs/examples/vllm/build/ray_init.sh +subcommand=$1 +shift + +ray_port=6379 +ray_init_timeout=1200 # Needs to be large enough to overcome any skew from the s5cmd command + any pod startup time + +case "$subcommand" in + worker) + ray_address="" + while [ $# -gt 0 ]; do + case "$1" in + --ray_address=*) + ray_address="${1#*=}" + ;; + --ray_port=*) + ray_port="${1#*=}" + ;; + --ray_init_timeout=*) + ray_init_timeout="${1#*=}" + ;; + --own_address=*) + own_address="${1#*=}" + ;; + *) + echo "unknown argument: $1" + exit 1 + esac + shift + done + + if [ -z "$ray_address" ]; then + echo "Error: Missing argument --ray_address" + exit 1 + fi + for (( i=0; i < $ray_init_timeout; i+=5 )); do + ray start --address=$ray_address:$ray_port --block --node-ip-address=$own_address + if [ $? -eq 0 ]; then + echo "Worker: Ray runtime started with head address $ray_address:$ray_port" + exit 0 + fi + echo $? + echo "Waiting until the ray worker is active..." + sleep 5s; + done + echo "Ray worker starts timeout, head address: $ray_address:$ray_port" + exit 1 + ;; + + leader) + ray_cluster_size="" + while [ $# -gt 0 ]; do + case "$1" in + --ray_port=*) + ray_port="${1#*=}" + ;; + --ray_cluster_size=*) + ray_cluster_size="${1#*=}" + ;; + --ray_init_timeout=*) + ray_init_timeout="${1#*=}" + ;; + --own_address=*) + own_address="${1#*=}" + ;; + *) + echo "unknown argument: $1" + exit 1 + esac + shift + done + + if [ -z "$ray_cluster_size" ]; then + echo "Error: Missing argument --ray_cluster_size" + exit 1 + fi + + # start the ray daemon + ray start --head --port=$ray_port --node-ip-address=$own_address + # wait until all workers are active + for (( i=0; i < $ray_init_timeout; i+=5 )); do + active_nodes=`python3 -c 'import ray; ray.init(); print(sum(node["Alive"] for node in ray.nodes()))'` + if [ $active_nodes -eq $ray_cluster_size ]; then + echo "All ray workers are active and the ray cluster is initialized successfully." + exit 0 + fi + echo "Wait for all ray workers to be active. $active_nodes/$ray_cluster_size is active" + sleep 5s; + done + + echo "Waiting for all ray workers to be active timed out." + exit 1 + ;; + + *) + echo "unknown subcommand: $subcommand" + exit 1 + ;; +esac diff --git a/model-engine/model_engine_server/inference/vllm/requirements-batch.txt b/model-engine/model_engine_server/inference/vllm/requirements-batch.txt new file mode 100644 index 00000000..04afaa23 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/requirements-batch.txt @@ -0,0 +1,7 @@ +pydantic>=2.8 +boto3==1.34.15 +smart-open==6.4.0 +ddtrace==2.11.0 +datadog==0.49.1 +dataclasses-json~=0.6.7 +sse-starlette==2.1.3 \ No newline at end of file diff --git a/model-engine/model_engine_server/inference/vllm/requirements-dev.txt b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt new file mode 100644 index 00000000..b75668a1 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/requirements-dev.txt @@ -0,0 +1 @@ +vllm==0.6.3 diff --git a/model-engine/model_engine_server/inference/vllm/requirements.txt b/model-engine/model_engine_server/inference/vllm/requirements.txt new file mode 100644 index 00000000..3381d938 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/requirements.txt @@ -0,0 +1 @@ +pydantic>=2.0 diff --git a/model-engine/model_engine_server/inference/vllm/vllm_batch.py b/model-engine/model_engine_server/inference/vllm/vllm_batch.py new file mode 100644 index 00000000..d24ca74e --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/vllm_batch.py @@ -0,0 +1,372 @@ +import argparse +import asyncio +import json +import os +import subprocess +from typing import ( + Any, + AsyncGenerator, + AsyncIterator, + Coroutine, + Dict, + List, + MutableMapping, + Optional, + Union, +) + +import smart_open +from fastapi import Request +from model_engine_server.common.dtos.llms import ( + BatchCompletionContent, + BatchCompletionsModelConfig, + CompletionResponse, + CompletionV1Output, + CreateBatchCompletionsEngineRequest, + CreateBatchCompletionsV1RequestContent, + TokenOutput, + VLLMModelConfig, +) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) +from model_engine_server.inference.utils import ( + await_coroutines, + check_unknown_startup_memory_usage, + get_cpu_cores_in_container, + random_uuid, +) +from pydantic import TypeAdapter +from starlette.datastructures import Headers +from tqdm import tqdm +from typing_extensions import TypeAlias, assert_never +from vllm import AsyncEngineArgs, AsyncLLMEngine, RequestOutput, SamplingParams +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.openai.protocol import ChatCompletionRequest, CompletionRequest, ErrorResponse +from vllm.entrypoints.openai.serving_chat import OpenAIServingChat +from vllm.entrypoints.openai.serving_completion import OpenAIServingCompletion +from vllm.entrypoints.openai.serving_engine import BaseModelPath +from vllm.utils import merge_async_iterators + +CONFIG_FILE = os.getenv("CONFIG_FILE") +AWS_REGION = os.getenv("AWS_REGION", "us-west-2") +MODEL_WEIGHTS_FOLDER = os.getenv("MODEL_WEIGHTS_FOLDER", "./model_weights") + +SKIP_AWS_PROFILE_SET = os.getenv("SKIP_AWS_PROFILE_SET", "false").lower() == "true" +if not SKIP_AWS_PROFILE_SET: + os.environ["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + + +openai_serving_chat: OpenAIServingChat +openai_serving_completion: OpenAIServingCompletion + +CPU_COUNT = get_cpu_cores_in_container() + +_BatchCompletionContent: TypeAlias = Union[ + CreateBatchCompletionsV1RequestContent, + List[CompletionRequest], + List[ChatCompletionRequest], +] + + +async def dummy_receive() -> MutableMapping[str, Any]: + return {"type": "continue"} + + +# jank but create_completion expects a FastAPI Request object +dummy_request = Request( + scope={ + "type": "http", + "path": "/", + "headers": Headers().raw, + "http_version": "1.1", + "method": "GET", + "scheme": "https", + "client": ("127.0.0.1", 8080), + }, + # receive fn that doesn't terminate + receive=dummy_receive, +) + + +async def download_model(checkpoint_path: str, target_dir: str, trust_remote_code: bool) -> None: + additional_include = "--include '*.py'" if trust_remote_code else "" + s5cmd = f"./s5cmd --numworkers 512 sync --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' {additional_include} --exclude 'optimizer*' --exclude 'train*' {os.path.join(checkpoint_path, '*')} {target_dir}" + env = os.environ.copy() + env["AWS_PROFILE"] = os.getenv("S3_WRITE_AWS_PROFILE", "default") + # Need to override these env vars so s5cmd uses AWS_PROFILE + env["AWS_ROLE_ARN"] = "" + env["AWS_WEB_IDENTITY_TOKEN_FILE"] = "" + process = subprocess.Popen( + s5cmd, + shell=True, # nosemgrep + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + text=True, + env=env, + ) + if process.stdout: + for line in process.stdout: + print(line, flush=True) + + process.wait() + + if process.returncode != 0 and process.stderr: + stderr_lines = [] + for line in iter(process.stderr.readline, ""): + stderr_lines.append(line.strip()) + + print(f"Error downloading model weights: {stderr_lines}", flush=True) + + +async def generate_v1_completions( + engine: EngineClient, + content: CreateBatchCompletionsV1RequestContent, +) -> List[Optional[CompletionV1Output]]: + prompts = content.prompts + bar = tqdm(total=len(prompts), desc="Processed prompts") + sampling_params = SamplingParams( + max_tokens=content.max_new_tokens, + temperature=content.temperature, + stop=content.stop_sequences, + logprobs=1 if content.return_token_log_probs else None, + presence_penalty=content.presence_penalty or 0.0, + frequency_penalty=content.frequency_penalty or 0.0, + top_k=content.top_k or -1, + top_p=content.top_p or 1.0, + skip_special_tokens=( + content.skip_special_tokens if content.skip_special_tokens is not None else True + ), + ) + + results_generators: List[AsyncIterator[RequestOutput]] = [] + for prompt in prompts: + request_id = random_uuid() + results_generator = engine.generate( + prompt, + sampling_params=sampling_params, + request_id=request_id, + ) + results_generators.append(results_generator) + + return_token_log_probs = True + + generator = merge_async_iterators(*results_generators) + outputs: List[Optional[CompletionV1Output]] = [None] * len(prompts) + tokens: List[List[TokenOutput]] = [list() for _ in prompts] + async for i, res in generator: + # There should only be one output + output = res.outputs[-1] + + if return_token_log_probs and output.logprobs is not None: + # Sometime the logprobs are not present in the output + logprobs = output.logprobs[-1] + for token_id in logprobs.keys(): + tokens[i].append( + TokenOutput( + token=logprobs[token_id].decoded_token, + log_prob=logprobs[token_id].logprob, + ) + ) + + if res.finished: + outputs[i] = CompletionV1Output( + text=output.text, + num_prompt_tokens=len(res.prompt_token_ids), + num_completion_tokens=len(output.token_ids), + tokens=[ + token.model_dump() for token in tokens[i] + ], # Not sure why, but pydantic doesn't like when I pass it TokenOutput directly but works when I encode it as a dict... + ) + bar.update(1) + + return outputs + + +async def generate_v2_completions( + engine: EngineClient, + requests: Union[List[CompletionRequest], List[ChatCompletionRequest]], +) -> List[Union[CompletionResponse, ErrorResponse, None]]: + bar = tqdm(total=len(requests), desc="Processed requests") + results_generators: List[ + Coroutine[ + Any, + Any, + Union[ErrorResponse, AsyncGenerator[str, None], CompletionResponse], + ] + ] = [] + for request in requests: + if isinstance(request, CompletionRequest): + results_generators.append( + openai_serving_completion.create_completion(request, dummy_request) + ) + elif isinstance(request, ChatCompletionRequest): + results_generators.append(openai_serving_chat.create_chat_completion(request)) + else: + assert_never(request) + + results_generator = await_coroutines(*results_generators) + outputs: List[Optional[CompletionResponse]] = [None] * len(requests) + + async for i, res in results_generator: + if isinstance(res, AsyncGenerator): + continue + outputs[i] = res + bar.update(1) + return outputs + + +async def generate_completions( + engine: EngineClient, request: _BatchCompletionContent +) -> Union[List[Optional[CompletionV1Output]], List[Optional[CompletionResponse]]]: + if isinstance(request, CreateBatchCompletionsV1RequestContent): + return await generate_v1_completions(engine, request) + elif isinstance(request, List): + return await generate_v2_completions(engine, request) + else: + assert_never(request) + + +async def init_engine( + model: str, + request: CreateBatchCompletionsEngineRequest, +) -> EngineClient: + global openai_serving_chat + global openai_serving_completion + + if request.attention_backend is not None: + os.environ["ATTENTION_BACKEND"] = request.attention_backend + + parsed_configs = VLLMModelConfig.model_validate_json(request.model_cfg.model_dump_json()) + if not parsed_configs.max_model_len: + parsed_configs.max_model_len = request.model_cfg.max_context_length + + print("VLLM additional configs:", parsed_configs.model_dump()) + + engine_args_dict = parsed_configs.model_dump(exclude_none=True) + default_engine_args_dict = dict( + model=model, + tensor_parallel_size=request.model_cfg.num_shards, + seed=request.model_cfg.seed or 0, + disable_log_requests=True, + gpu_memory_utilization=request.max_gpu_memory_utilization or 0.9, + ) + default_engine_args_dict.update(engine_args_dict) + + engine_args = AsyncEngineArgs(**default_engine_args_dict) + + engine_client = AsyncLLMEngine.from_engine_args(engine_args) + model_config = await engine_client.get_model_config() + base_model_paths = [BaseModelPath(name=model, model_path=model)] + + openai_serving_chat = OpenAIServingChat( + engine_client, + model_config, + base_model_paths, + response_role=request.model_cfg.response_role or "assistant", + lora_modules=None, + prompt_adapters=None, + request_logger=None, + chat_template=None, + ) + + openai_serving_completion = OpenAIServingCompletion( + engine_client, + model_config, + base_model_paths, + lora_modules=None, + prompt_adapters=None, + request_logger=None, + ) + + return engine_client + + +def overwrite_request(request: Dict[str, Any], model: str) -> Dict[str, Any]: + request["model"] = model + request["stream"] = False + return request + + +def load_batch_content( + request: CreateBatchCompletionsEngineRequest, +) -> _BatchCompletionContent: + content = request.content + if content is None: + with smart_open.open(request.input_data_path, "r") as f: + data = json.load(f) + content = TypeAdapter(BatchCompletionContent).validate_python(data) + + # Recast the content to vLLMs schema + if isinstance(content, List) and len(content) > 0: + model = get_model_name(request.model_cfg) + return TypeAdapter( + Union[List[CompletionRequest], List[ChatCompletionRequest]] + ).validate_python( + [ + overwrite_request(req.model_dump(exclude_none=True, mode="json"), model) + for req in content + ] + ) + + return content + + +def get_model_name(model_config: BatchCompletionsModelConfig) -> str: + return MODEL_WEIGHTS_FOLDER if model_config.checkpoint_path else model_config.model + + +async def handle_batch_job(request: CreateBatchCompletionsEngineRequest) -> None: + metrics_gateway = DatadogInferenceMonitoringMetricsGateway() + + model = get_model_name(request.model_cfg) + if request.model_cfg.checkpoint_path: + await download_model( + checkpoint_path=request.model_cfg.checkpoint_path, + target_dir=MODEL_WEIGHTS_FOLDER, + trust_remote_code=request.model_cfg.trust_remote_code or False, + ) + + content = load_batch_content(request) + engine = await init_engine( + model, + request=request, + ) + + outputs = await generate_completions(engine, content) + with smart_open.open(request.output_data_path, "w") as f: + f.write(json.dumps([output.model_dump() if output else None for output in outputs])) + + metrics_gateway.emit_batch_completions_metric( + model, + use_tool=False, + num_prompt_tokens=0, + num_completion_tokens=0, + is_finetuned=True, + ) + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "--config-file-data", + "--config_file_data", + type=str, + default=None, + help="Optional override for the config file data, as a json string", + ) + + args = parser.parse_args() + + check_unknown_startup_memory_usage() + + config_file_data = args.config_file_data + if config_file_data is None: + if CONFIG_FILE is None or not os.path.exists(CONFIG_FILE): + raise FileNotFoundError(f"Config file {CONFIG_FILE} not found") + with open(CONFIG_FILE, "r") as f: + config_file_data = f.read() + + request = CreateBatchCompletionsEngineRequest.model_validate_json(config_file_data) + + asyncio.run(handle_batch_job(request)) diff --git a/model-engine/model_engine_server/inference/vllm/vllm_server.py b/model-engine/model_engine_server/inference/vllm/vllm_server.py new file mode 100644 index 00000000..183d5c64 --- /dev/null +++ b/model-engine/model_engine_server/inference/vllm/vllm_server.py @@ -0,0 +1,245 @@ +import asyncio +import code +import json +import os +import signal +import socket +import subprocess +import traceback +from logging import Logger +from typing import AsyncGenerator, Dict, List, Optional + +from fastapi import APIRouter, BackgroundTasks, Request +from fastapi.responses import Response, StreamingResponse +from vllm.engine.async_llm_engine import ( + AsyncEngineDeadError, + build_guided_decoding_logits_processor_async, +) +from vllm.engine.protocol import EngineClient +from vllm.entrypoints.launcher import serve_http +from vllm.entrypoints.openai.api_server import build_app, build_async_engine_client, init_app_state +from vllm.entrypoints.openai.cli_args import make_arg_parser +from vllm.outputs import CompletionOutput +from vllm.sampling_params import SamplingParams +from vllm.sequence import Logprob +from vllm.utils import FlexibleArgumentParser, random_uuid +from vllm.version import __version__ as VLLM_VERSION + +logger = Logger("vllm_server") + +engine_client: EngineClient + +TIMEOUT_KEEP_ALIVE = 5 # seconds. +TIMEOUT_TO_PREVENT_DEADLOCK = 1 # seconds + +router = APIRouter() + + +@router.post("/predict") +@router.post("/stream") +async def generate(request: Request) -> Response: + """Generate completion for the request. + + The request should be a JSON object with the following fields: + - prompt: the prompt to use for the generation. + - stream: whether to stream the results or not. + - other fields: the sampling parameters (See `SamplingParams` for details). + """ + # check health before accepting request and fail fast if engine isn't healthy + try: + await engine_client.check_health() + + request_dict = await request.json() + prompt = request_dict.pop("prompt") + stream = request_dict.pop("stream", False) + + guided_decoding_backend = ( + await engine_client.get_decoding_config() + ).guided_decoding_backend + + sampling_params = await build_guided_decoding_logits_processor_async( + sampling_params=SamplingParams(**request_dict), + tokenizer=await engine_client.get_tokenizer(lora_request=None), + default_guided_backend=guided_decoding_backend, + ) + + request_id = random_uuid() + + results_generator = engine_client.generate(prompt, sampling_params, request_id) + + async def abort_request() -> None: + await engine_client.abort(request_id) + + if stream: + # Streaming case + async def stream_results() -> AsyncGenerator[str, None]: + last_output_text = "" + async for request_output in results_generator: + log_probs = format_logprobs(request_output) + ret = { + "text": request_output.outputs[-1].text[len(last_output_text) :], + "count_prompt_tokens": len(request_output.prompt_token_ids), + "count_output_tokens": len(request_output.outputs[0].token_ids), + "log_probs": ( + log_probs[-1] if log_probs and sampling_params.logprobs else None + ), + "finished": request_output.finished, + } + last_output_text = request_output.outputs[-1].text + yield f"data:{json.dumps(ret)}\n\n" + + background_tasks = BackgroundTasks() + # Abort the request if the client disconnects. + background_tasks.add_task(abort_request) + + return StreamingResponse(stream_results(), background=background_tasks) + + # Non-streaming case + final_output = None + tokens = [] + last_output_text = "" + async for request_output in results_generator: + tokens.append(request_output.outputs[-1].text[len(last_output_text) :]) + last_output_text = request_output.outputs[-1].text + if await request.is_disconnected(): + # Abort the request if the client disconnects. + await engine_client.abort(request_id) + return Response(status_code=499) + final_output = request_output + + assert final_output is not None + prompt = final_output.prompt + ret = { + "text": final_output.outputs[0].text, + "count_prompt_tokens": len(final_output.prompt_token_ids), + "count_output_tokens": len(final_output.outputs[0].token_ids), + "log_probs": format_logprobs(final_output), + "tokens": tokens, + } + return Response(content=json.dumps(ret)) + + except AsyncEngineDeadError as e: + logger.error(f"The vllm engine is dead, exiting the pod: {e}") + os.kill(os.getpid(), signal.SIGINT) + raise e + + +def get_gpu_free_memory(): + """Get GPU free memory using nvidia-smi.""" + try: + output = subprocess.run( + ["nvidia-smi", "--query-gpu=memory.free", "--format=csv,noheader,nounits"], + capture_output=True, + text=True, + ).stdout + gpu_memory = [int(x) for x in output.strip().split("\n")] + return gpu_memory + except Exception as e: + logger.warn(f"Error getting GPU memory: {e}") + return None + + +def check_unknown_startup_memory_usage(): + """Check for unknown memory usage at startup.""" + gpu_free_memory = get_gpu_free_memory() + if gpu_free_memory is not None: + min_mem = min(gpu_free_memory) + max_mem = max(gpu_free_memory) + if max_mem - min_mem > 10: + logger.warn( + f"WARNING: Unbalanced GPU memory usage at start up. This may cause OOM. Memory usage per GPU in MB: {gpu_free_memory}." + ) + try: + # nosemgrep + output = subprocess.run( + ["fuser -v /dev/nvidia*"], + shell=False, + capture_output=True, + text=True, + ).stdout + logger.info(f"Processes using GPU: {output}") + except Exception as e: + logger.error(f"Error getting processes using GPU: {e}") + + +def debug(sig, frame): + """Interrupt running process, and provide a python prompt for + interactive debugging.""" + d = {"_frame": frame} # Allow access to frame object. + d.update(frame.f_globals) # Unless shadowed by global + d.update(frame.f_locals) + + i = code.InteractiveConsole(d) + message = "Signal received : entering python shell.\nTraceback:\n" + message += "".join(traceback.format_stack(frame)) + i.interact(message) + + +def format_logprobs( + request_output: CompletionOutput, +) -> Optional[List[Dict[int, float]]]: + """Given a request output, format the logprobs if they exist.""" + output_logprobs = request_output.outputs[0].logprobs + if output_logprobs is None: + return None + + def extract_logprobs(logprobs: Dict[int, Logprob]) -> Dict[int, float]: + return {k: v.logprob for k, v in logprobs.items()} + + return [extract_logprobs(logprobs) for logprobs in output_logprobs] + + +def parse_args(parser: FlexibleArgumentParser): + parser = make_arg_parser(parser) + parser.add_argument("--attention-backend", type=str, help="The attention backend to use") + return parser.parse_args() + + +async def run_server(args, **uvicorn_kwargs) -> None: + logger.info("vLLM API server version %s", VLLM_VERSION) + logger.info("args: %s", args) + + temp_socket = socket.socket(socket.AF_INET, socket.SOCK_STREAM) # nosemgrep + temp_socket.bind(("", args.port)) + + def signal_handler(*_) -> None: + # Interrupt server on sigterm while initializing + raise KeyboardInterrupt("terminated") + + signal.signal(signal.SIGTERM, signal_handler) + + global engine_client + async with build_async_engine_client(args) as engine_client: + app = build_app(args) + + model_config = await engine_client.get_model_config() + init_app_state(engine_client, model_config, app.state, args) + + temp_socket.close() + app.include_router(router) + + shutdown_task = await serve_http( + app, + host=args.host, + port=args.port, + log_level=args.uvicorn_log_level, + timeout_keep_alive=TIMEOUT_KEEP_ALIVE, + ssl_keyfile=args.ssl_keyfile, + ssl_certfile=args.ssl_certfile, + ssl_ca_certs=args.ssl_ca_certs, + ssl_cert_reqs=args.ssl_cert_reqs, + **uvicorn_kwargs, + ) + + # NB: Await server shutdown only after the backend context is exited + await shutdown_task + + +if __name__ == "__main__": + check_unknown_startup_memory_usage() + + parser = FlexibleArgumentParser() + args = parse_args(parser) + if args.attention_backend is not None: + os.environ["VLLM_ATTENTION_BACKEND"] = args.attention_backend + asyncio.run(run_server(args)) diff --git a/model-engine/model_engine_server/infra/__init__.py b/model-engine/model_engine_server/infra/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/llm_engine_server/infra/gateways/__init__.py b/model-engine/model_engine_server/infra/gateways/__init__.py similarity index 68% rename from server/llm_engine_server/infra/gateways/__init__.py rename to model-engine/model_engine_server/infra/gateways/__init__.py index c7c5a2af..f8a3ee6e 100644 --- a/server/llm_engine_server/infra/gateways/__init__.py +++ b/model-engine/model_engine_server/infra/gateways/__init__.py @@ -1,14 +1,19 @@ from typing import Sequence +from .abs_file_storage_gateway import ABSFileStorageGateway +from .abs_filesystem_gateway import ABSFilesystemGateway +from .abs_llm_artifact_gateway import ABSLLMArtifactGateway +from .asb_inference_autoscaling_metrics_gateway import ASBInferenceAutoscalingMetricsGateway from .batch_job_orchestration_gateway import BatchJobOrchestrationGateway from .batch_job_progress_gateway import BatchJobProgressGateway from .celery_task_queue_gateway import CeleryTaskQueueGateway from .datadog_monitoring_metrics_gateway import DatadogMonitoringMetricsGateway +from .fake_model_primitive_gateway import FakeModelPrimitiveGateway from .fake_monitoring_metrics_gateway import FakeMonitoringMetricsGateway -from .filesystem_gateway import FilesystemGateway from .live_async_model_endpoint_inference_gateway import LiveAsyncModelEndpointInferenceGateway from .live_batch_job_orchestration_gateway import LiveBatchJobOrchestrationGateway from .live_batch_job_progress_gateway import LiveBatchJobProgressGateway +from .live_cron_job_gateway import LiveCronJobGateway from .live_docker_image_batch_job_gateway import LiveDockerImageBatchJobGateway from .live_model_endpoint_infra_gateway import LiveModelEndpointInfraGateway from .live_model_endpoints_schema_gateway import LiveModelEndpointsSchemaGateway @@ -17,23 +22,32 @@ ) from .live_sync_model_endpoint_inference_gateway import LiveSyncModelEndpointInferenceGateway from .model_endpoint_infra_gateway import ModelEndpointInfraGateway +from .redis_inference_autoscaling_metrics_gateway import RedisInferenceAutoscalingMetricsGateway from .s3_filesystem_gateway import S3FilesystemGateway +from .s3_llm_artifact_gateway import S3LLMArtifactGateway __all__: Sequence[str] = [ + "ABSFileStorageGateway", + "ABSFilesystemGateway", + "ABSLLMArtifactGateway", + "ASBInferenceAutoscalingMetricsGateway", "BatchJobOrchestrationGateway", "BatchJobProgressGateway", "CeleryTaskQueueGateway", "DatadogMonitoringMetricsGateway", + "FakeModelPrimitiveGateway", "FakeMonitoringMetricsGateway", - "FilesystemGateway", "LiveAsyncModelEndpointInferenceGateway", "LiveBatchJobOrchestrationGateway", "LiveBatchJobProgressGateway", + "LiveCronJobGateway", "LiveDockerImageBatchJobGateway", "LiveModelEndpointInfraGateway", "LiveModelEndpointsSchemaGateway", "LiveStreamingModelEndpointInferenceGateway", "LiveSyncModelEndpointInferenceGateway", "ModelEndpointInfraGateway", + "RedisInferenceAutoscalingMetricsGateway", "S3FilesystemGateway", + "S3LLMArtifactGateway", ] diff --git a/model-engine/model_engine_server/infra/gateways/abs_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_file_storage_gateway.py new file mode 100644 index 00000000..a12a0cb7 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/abs_file_storage_gateway.py @@ -0,0 +1,34 @@ +from typing import List, Optional + +from model_engine_server.domain.gateways.file_storage_gateway import ( + FileMetadata, + FileStorageGateway, +) +from model_engine_server.infra.gateways.abs_filesystem_gateway import ABSFilesystemGateway + + +class ABSFileStorageGateway(FileStorageGateway): + """ + Concrete implementation of a file storage gateway backed by ABS. + """ + + def __init__(self): + self.filesystem_gateway = ABSFilesystemGateway() + + async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: + raise NotImplementedError("ABS not supported yet") + + async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: + raise NotImplementedError("ABS not supported yet") + + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + raise NotImplementedError("ABS not supported yet") + + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + raise NotImplementedError("ABS not supported yet") + + async def delete_file(self, owner: str, file_id: str) -> bool: + raise NotImplementedError("ABS not supported yet") + + async def list_files(self, owner: str) -> List[FileMetadata]: + raise NotImplementedError("ABS not supported yet") diff --git a/model-engine/model_engine_server/infra/gateways/abs_filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_filesystem_gateway.py new file mode 100644 index 00000000..abf6f99e --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/abs_filesystem_gateway.py @@ -0,0 +1,48 @@ +import os +import re +from datetime import datetime, timedelta +from typing import IO + +import smart_open +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobSasPermissions, BlobServiceClient, generate_blob_sas +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway + + +class ABSFilesystemGateway(FilesystemGateway): + """ + Concrete implementation for interacting with a filesystem backed by Azure Blob Storage. + """ + + # uri should start with azure:// (as opposed to https://) unless the container is publicly accessible + def open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str: + match = re.search("^https://([^/]+)\.blob\.core\.windows\.net/([^/]+)/(.*?)$", uri) + assert match + + account_name, container_name, blob_name = match.group(1), match.group(2), match.group(3) + + blob_service_client = BlobServiceClient( + f"https://{account_name}.blob.core.windows.net", DefaultAzureCredential() + ) + user_delegation_key = blob_service_client.get_user_delegation_key( + datetime.utcnow(), datetime.utcnow() + timedelta(seconds=expiration) + ) + + sas_blob = generate_blob_sas( + account_name=account_name, + container_name=container_name, + blob_name=blob_name, + user_delegation_key=user_delegation_key, + permission=BlobSasPermissions(read=True, write=False, create=False), + expiry=datetime.utcnow() + timedelta(seconds=expiration), + **kwargs, + ) + return uri + "?" + sas_blob diff --git a/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py new file mode 100644 index 00000000..a1236138 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/abs_llm_artifact_gateway.py @@ -0,0 +1,85 @@ +import json +import os +from typing import Any, Dict, List + +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient, ContainerClient +from model_engine_server.common.config import get_model_cache_directory_name, hmi_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.url import parse_attachment_url +from model_engine_server.domain.gateways import LLMArtifactGateway + +logger = make_logger(logger_name()) + + +def _get_abs_container_client(bucket: str) -> ContainerClient: + blob_service_client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + return blob_service_client.get_container_client(container=bucket) + + +class ABSLLMArtifactGateway(LLMArtifactGateway): + """ + Concrete implemention using Azure Blob Storage. + """ + + def list_files(self, path: str, **kwargs) -> List[str]: + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = parsed_remote.key + + container_client = _get_abs_container_client(bucket) + return list(container_client.list_blob_names(name_starts_with=key)) + + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = parsed_remote.key + + container_client = _get_abs_container_client(bucket) + + downloaded_files: List[str] = [] + for blob in container_client.list_blobs(name_starts_with=key): + file_path_suffix = blob.name.replace(key, "").lstrip("/") + local_path = os.path.join(target_path, file_path_suffix).rstrip("/") + + if not overwrite and os.path.exists(local_path): + downloaded_files.append(local_path) + continue + + local_dir = "/".join(local_path.split("/")[:-1]) + if not os.path.exists(local_dir): + os.makedirs(local_dir) + + logger.info(f"Downloading {blob.name} to {local_path}") + with open(file=local_path, mode="wb") as f: + f.write(container_client.download_blob(blob.name).readall()) + downloaded_files.append(local_path) + return downloaded_files + + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: + parsed_remote = parse_attachment_url( + hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False + ) + account = parsed_remote.account + bucket = parsed_remote.bucket + fine_tuned_weights_prefix = parsed_remote.key + + container_client = _get_abs_container_client(bucket) + + model_files: List[str] = [] + model_cache_name = get_model_cache_directory_name(model_name) + prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}" + for blob_name in container_client.list_blob_names(name_starts_with=prefix): + model_files.append(f"https://{account}.blob.core.windows.net/{bucket}/{blob_name}") + return model_files + + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = os.path.join(parsed_remote.key, "config.json") + + container_client = _get_abs_container_client(bucket) + return json.loads(container_client.download_blob(blob=key).readall()) diff --git a/model-engine/model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py new file mode 100644 index 00000000..6ab06a27 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py @@ -0,0 +1,72 @@ +import os +from datetime import timedelta + +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.servicebus import ServiceBusClient, ServiceBusMessage +from azure.servicebus.management import ServiceBusAdministrationClient +from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import ( + InferenceAutoscalingMetricsGateway, +) + +EXPIRY_SECONDS = 60 # 1 minute; this gets added to the cooldown time present in the keda ScaledObject to get total +# scaledown time. This also needs to be larger than the keda ScaledObject's refresh rate. +PREWARM_EXPIRY_SECONDS = 60 * 60 # 1 hour + + +def _get_servicebus_administration_client() -> ServiceBusAdministrationClient: + return ServiceBusAdministrationClient( + f"{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net", + credential=DefaultAzureCredential(), + ) + + +class ASBInferenceAutoscalingMetricsGateway(InferenceAutoscalingMetricsGateway): + @staticmethod + def _find_queue_name(endpoint_id: str): + # Keep in line with keda scaled object yaml + return f"launch-endpoint-autoscaling.{endpoint_id}" + + async def _emit_metric(self, endpoint_id: str, expiry_time: int): + queue_name = self._find_queue_name(endpoint_id) + + servicebus_namespace = os.getenv("SERVICEBUS_NAMESPACE") + if servicebus_namespace is None: + raise ValueError("SERVICEBUS_NAMESPACE env var must be set in Azure") + + with ServiceBusClient( + fully_qualified_namespace=f"{servicebus_namespace}.servicebus.windows.net", + credential=DefaultAzureCredential(), + ) as servicebus_client: + sender = servicebus_client.get_queue_sender(queue_name=queue_name) + with sender: + message = ServiceBusMessage( + "message", time_to_live=timedelta(seconds=expiry_time) + ) # we only care about the length of the queue, not the message values + sender.send_messages(message=message) + + receiver = servicebus_client.get_queue_receiver(queue_name=queue_name) + with receiver: + receiver.peek_messages(max_message_count=1, timeout=1) + + async def emit_inference_autoscaling_metric(self, endpoint_id: str): + await self._emit_metric(endpoint_id, EXPIRY_SECONDS) + + async def emit_prewarm_metric(self, endpoint_id: str): + await self._emit_metric(endpoint_id, PREWARM_EXPIRY_SECONDS) + + async def create_or_update_resources(self, endpoint_id: str): + queue_name = self._find_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + client.create_queue(queue_name=queue_name) + except ResourceExistsError: + pass + + async def delete_resources(self, endpoint_id: str): + queue_name = self._find_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + client.delete_queue(queue_name=queue_name) + except ResourceNotFoundError: + pass diff --git a/server/llm_engine_server/infra/gateways/batch_job_orchestration_gateway.py b/model-engine/model_engine_server/infra/gateways/batch_job_orchestration_gateway.py similarity index 94% rename from server/llm_engine_server/infra/gateways/batch_job_orchestration_gateway.py rename to model-engine/model_engine_server/infra/gateways/batch_job_orchestration_gateway.py index 57c40394..5ce0bb1e 100644 --- a/server/llm_engine_server/infra/gateways/batch_job_orchestration_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/batch_job_orchestration_gateway.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Dict -from llm_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.entities import BatchJobSerializationFormat class BatchJobOrchestrationGateway(ABC): diff --git a/server/llm_engine_server/infra/gateways/batch_job_progress_gateway.py b/model-engine/model_engine_server/infra/gateways/batch_job_progress_gateway.py similarity index 92% rename from server/llm_engine_server/infra/gateways/batch_job_progress_gateway.py rename to model-engine/model_engine_server/infra/gateways/batch_job_progress_gateway.py index ab20bd34..e1da816d 100644 --- a/server/llm_engine_server/infra/gateways/batch_job_progress_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/batch_job_progress_gateway.py @@ -1,6 +1,6 @@ from abc import ABC, abstractmethod -from llm_engine_server.domain.entities import BatchJobProgress +from model_engine_server.domain.entities import BatchJobProgress class BatchJobProgressGateway(ABC): diff --git a/server/llm_engine_server/infra/gateways/celery_task_queue_gateway.py b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py similarity index 58% rename from server/llm_engine_server/infra/gateways/celery_task_queue_gateway.py rename to model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py index 7017e6ef..676a1274 100644 --- a/server/llm_engine_server/infra/gateways/celery_task_queue_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/celery_task_queue_gateway.py @@ -1,31 +1,42 @@ from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.dtos.tasks import ( +import botocore +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, GetAsyncTaskV1Response, TaskStatus, ) -from llm_engine_server.core.celery import TaskVisibility, celery_app -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway +from model_engine_server.core.celery import TaskVisibility, celery_app +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import InvalidRequestException +from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) +backend_protocol = "abs" if infra_config().cloud_provider == "azure" else "s3" celery_redis = celery_app( None, - s3_bucket=ml_infra_config().s3_bucket, + s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value), + backend_protocol=backend_protocol, ) celery_redis_24h = celery_app( None, - s3_bucket=ml_infra_config().s3_bucket, + s3_bucket=infra_config().s3_bucket, broker_type=str(BrokerType.REDIS.value), task_visibility=TaskVisibility.VISIBILITY_24H, + backend_protocol=backend_protocol, ) celery_sqs = celery_app( - None, s3_bucket=ml_infra_config().s3_bucket, broker_type=str(BrokerType.SQS.value) + None, + s3_bucket=infra_config().s3_bucket, + broker_type=str(BrokerType.SQS.value), + backend_protocol=backend_protocol, +) +celery_servicebus = celery_app( + None, broker_type=str(BrokerType.SERVICEBUS.value), backend_protocol=backend_protocol ) @@ -36,6 +47,7 @@ def __init__(self, broker_type: BrokerType): BrokerType.SQS, BrokerType.REDIS, BrokerType.REDIS_24H, + BrokerType.SERVICEBUS, ] def _get_celery_dest(self): @@ -43,8 +55,10 @@ def _get_celery_dest(self): return celery_sqs elif self.broker_type == BrokerType.REDIS_24H: return celery_redis_24h - else: # self.broker_type == BrokerType.REDIS + elif self.broker_type == BrokerType.REDIS: return celery_redis + else: + return celery_servicebus def send_task( self, @@ -55,16 +69,18 @@ def send_task( expires: Optional[int] = None, ) -> CreateAsyncTaskV1Response: celery_dest = self._get_celery_dest() - logger.info( - f"Sending task {task_name} with args {args} kwargs {kwargs} to queue {queue_name}" - ) - res = celery_dest.send_task( - name=task_name, - args=args, - kwargs=kwargs, - queue=queue_name, - ) - logger.info(f"Response from sending task {task_name}: {res}") + + try: + res = celery_dest.send_task( + name=task_name, + args=args, + kwargs=kwargs, + queue=queue_name, + ) + except botocore.exceptions.ClientError as e: + logger.exception(f"Error sending task to queue {queue_name}: {e}") + raise InvalidRequestException(f"Error sending celery task: {e}") + logger.info(f"Task {res.id} sent to queue {queue_name} from gateway") # pragma: no cover return CreateAsyncTaskV1Response(task_id=res.id) def get_task(self, task_id: str) -> GetAsyncTaskV1Response: diff --git a/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py new file mode 100644 index 00000000..93a73970 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py @@ -0,0 +1,92 @@ +from typing import List, Optional + +from datadog import statsd +from model_engine_server.common.dtos.llms import TokenUsage +from model_engine_server.core.config import infra_config +from model_engine_server.domain.gateways.monitoring_metrics_gateway import ( + MetricMetadata, + MonitoringMetricsGateway, +) + + +def get_model_tags(model_name: Optional[str]) -> List[str]: + """ + Returns a tag for the model name and whether it is a finetuned model + """ + tags = [] + if model_name: + parts = model_name.split(".") + tags.extend([f"model_name:{parts[0]}"]) + return tags + + +class DatadogMonitoringMetricsGateway(MonitoringMetricsGateway): + def __init__(self, prefix: str = "model_engine"): + self.prefix = prefix + self.tags = [f"env:{infra_config().env}"] + + def emit_attempted_build_metric(self): + statsd.increment("scale_launch.service_builder.attempt", tags=self.tags) + + def emit_successful_build_metric(self): + statsd.increment("scale_launch.service_builder.success", tags=self.tags) + + def emit_build_time_metric(self, duration_seconds: float): + statsd.distribution( + "scale_launch.service_builder.endpoint_build_time", duration_seconds, tags=self.tags + ) + + def emit_image_build_cache_hit_metric(self, image_type: str): + statsd.increment( + f"scale_launch.service_builder.{image_type}_image_cache_hit", tags=self.tags + ) + + def emit_image_build_cache_miss_metric(self, image_type: str): + statsd.increment( + f"scale_launch.service_builder.{image_type}_image_cache_miss", tags=self.tags + ) + + def emit_docker_failed_build_metric(self): + statsd.increment("scale_launch.service_builder.docker_failed", tags=self.tags) + + def emit_database_cache_hit_metric(self): + statsd.increment("scale_launch.database_cache.hit", tags=self.tags) + + def emit_database_cache_miss_metric(self): + statsd.increment("scale_launch.database_cache.miss", tags=self.tags) + + def _format_call_tags(self, metadata: MetricMetadata) -> List[str]: + tags = self.tags + tags.extend(get_model_tags(metadata.model_name)) + return tags + + def emit_route_call_metric(self, route: str, metadata: MetricMetadata): + statsd.increment(f"{self.prefix}.{route}.call", tags=self._format_call_tags(metadata)) + + def emit_token_count_metrics(self, token_usage: TokenUsage, metadata: MetricMetadata): + tags = self._format_call_tags(metadata) + + token_count_metric = f"{self.prefix}.token_count" + statsd.increment( + f"{token_count_metric}.prompt", (token_usage.num_prompt_tokens or 0), tags=tags + ) + statsd.increment( + f"{token_count_metric}.completion", (token_usage.num_completion_tokens or 0), tags=tags + ) + statsd.increment(f"{token_count_metric}.total", token_usage.num_total_tokens, tags=tags) + + total_tokens_per_second = f"{self.prefix}.total_tokens_per_second" + statsd.histogram(total_tokens_per_second, token_usage.total_tokens_per_second, tags=tags) + + time_to_first_token = f"{self.prefix}.time_to_first_token" + if token_usage.time_to_first_token is not None: + statsd.distribution(time_to_first_token, token_usage.time_to_first_token, tags=tags) + + inter_token_latency = f"{self.prefix}.inter_token_latency" + if token_usage.inter_token_latency is not None: + statsd.distribution(inter_token_latency, token_usage.inter_token_latency, tags=tags) + + def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): + tags = self.tags + tags.extend([f"endpoint_name:{endpoint_name}", f"error_code:{error_code}"]) + statsd.increment(f"{self.prefix}.upstream_sync_error", tags=tags) diff --git a/model-engine/model_engine_server/infra/gateways/dns_resolver.py b/model-engine/model_engine_server/infra/gateways/dns_resolver.py new file mode 100644 index 00000000..0579f98a --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/dns_resolver.py @@ -0,0 +1,22 @@ +import socket +from typing import Union + + +def resolve_dns(host: str, port: Union[str, int] = "http") -> str: + """ + Returns an IP address of the given host, e.g. "256.256.256.256" for IPv4, or + "[0000:0000:0000::0000]" for IPv6. You should be able to just substitute this into a URL. + """ + addrinfo = socket.getaddrinfo(host, port) + if len(addrinfo) == 0: + raise ValueError("Host not found.") + # Probably just need the first one + socket_type = addrinfo[0][0] + ip = addrinfo[0][4][0] + # Do I want to do anything with port? it probably ends up being the default (e.g. 80 for http, 443 for https) + if socket_type == socket.AF_INET6: + return f"[{ip}]" + elif socket_type == socket.AF_INET: + return ip + else: + raise ValueError("Unknown socket type.") diff --git a/server/llm_engine_server/infra/gateways/fake_model_primitive_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_model_primitive_gateway.py similarity index 82% rename from server/llm_engine_server/infra/gateways/fake_model_primitive_gateway.py rename to model-engine/model_engine_server/infra/gateways/fake_model_primitive_gateway.py index 18bebf0b..e095fa12 100644 --- a/server/llm_engine_server/infra/gateways/fake_model_primitive_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/fake_model_primitive_gateway.py @@ -1,7 +1,7 @@ from typing import Optional -from llm_engine_server.domain.entities import ModelBundleFrameworkType -from llm_engine_server.domain.gateways import ModelPrimitiveGateway +from model_engine_server.domain.entities import ModelBundleFrameworkType +from model_engine_server.domain.gateways import ModelPrimitiveGateway class FakeModelPrimitiveGateway(ModelPrimitiveGateway): diff --git a/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py new file mode 100644 index 00000000..25bf45fa --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py @@ -0,0 +1,81 @@ +from collections import defaultdict + +from model_engine_server.common.dtos.llms import TokenUsage +from model_engine_server.domain.gateways.monitoring_metrics_gateway import ( + MetricMetadata, + MonitoringMetricsGateway, +) + + +class FakeMonitoringMetricsGateway(MonitoringMetricsGateway): + def __init__(self): + self.attempted_build = 0 + self.successful_build = 0 + self.build_time_seconds = 0 + self.image_build_cache_hit = defaultdict(int) + self.image_build_cache_miss = defaultdict(int) + self.docker_failed_build = 0 + self.attempted_hook = defaultdict(int) + self.successful_hook = defaultdict(int) + self.database_cache_hit = 0 + self.database_cache_miss = 0 + self.route_call = defaultdict(int) + self.token_count = 0 + self.total_tokens_per_second = 0 + self.sync_call_timeout = defaultdict(int) + + def reset(self): + self.attempted_build = 0 + self.successful_build = 0 + self.build_time_seconds = 0 + self.image_build_cache_hit = defaultdict(int) + self.image_build_cache_miss = defaultdict(int) + self.docker_failed_build = 0 + self.attempted_hook = defaultdict(int) + self.successful_hook = defaultdict(int) + self.database_cache_hit = 0 + self.database_cache_miss = 0 + self.route_call = defaultdict(int) + self.token_count = 0 + self.total_tokens_per_second = 0 + self.sync_call_timeout = defaultdict(int) + + def emit_attempted_build_metric(self): + self.attempted_build += 1 + + def emit_successful_build_metric(self): + self.successful_build += 1 + + def emit_build_time_metric(self, duration_seconds: float): + self.build_time_seconds += duration_seconds + + def emit_image_build_cache_hit_metric(self, image_type: str): + self.image_build_cache_hit[image_type] += 1 + + def emit_image_build_cache_miss_metric(self, image_type: str): + self.image_build_cache_miss[image_type] += 1 + + def emit_docker_failed_build_metric(self): + self.docker_failed_build += 1 + + def emit_attempted_post_inference_hook(self, hook: str): + self.attempted_hook[hook] += 1 + + def emit_successful_post_inference_hook(self, hook: str): + self.successful_hook[hook] += 1 + + def emit_database_cache_hit_metric(self): + self.database_cache_hit += 1 + + def emit_database_cache_miss_metric(self): + self.database_cache_miss += 1 + + def emit_route_call_metric(self, route: str, _metadata: MetricMetadata): + self.route_call[route] += 1 + + def emit_token_count_metrics(self, token_usage: TokenUsage, _metadata: MetricMetadata): + self.token_count += token_usage.num_total_tokens + self.total_tokens_per_second = token_usage.total_tokens_per_second + + def emit_http_call_error_metrics(self, endpoint_name: str, error_code: int): + self.sync_call_timeout[(endpoint_name, error_code)] += 1 diff --git a/server/llm_engine_server/infra/gateways/filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/filesystem_gateway.py similarity index 100% rename from server/llm_engine_server/infra/gateways/filesystem_gateway.py rename to model-engine/model_engine_server/infra/gateways/filesystem_gateway.py diff --git a/server/llm_engine_server/infra/gateways/k8s_resource_parser.py b/model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py similarity index 85% rename from server/llm_engine_server/infra/gateways/k8s_resource_parser.py rename to model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py index a8626f65..947a734d 100644 --- a/server/llm_engine_server/infra/gateways/k8s_resource_parser.py +++ b/model-engine/model_engine_server/infra/gateways/k8s_resource_parser.py @@ -1,10 +1,8 @@ import hashlib +import math import re from typing import Union -MAX_CONCURRENCY_TO_TARGET_CONCURRENCY_RATIO = 2.0 - - # found this regex floating around somewhere, probably validates k8s requests in general: # '^([+-]?[0-9.]+)([eEinumkKMGTP]*[-+]?[0-9]*)$' @@ -57,12 +55,12 @@ def parse_mem_request(req: str): def get_node_port(service_name: str) -> int: """Hashes the service name to a port number in the range [30000, 32767]""" - return int(hashlib.md5(service_name.encode()).hexdigest(), 16) % (32768 - 30000) + 30000 + return int(hashlib.sha256(service_name.encode()).hexdigest(), 16) % (32768 - 30000) + 30000 def get_target_concurrency_from_per_worker_value(per_worker: int) -> float: """Returns the target concurrency given a per-worker value""" - return per_worker / MAX_CONCURRENCY_TO_TARGET_CONCURRENCY_RATIO + return per_worker def get_per_worker_value_from_target_concurrency(concurrency: Union[str, int, float]) -> int: @@ -70,13 +68,7 @@ def get_per_worker_value_from_target_concurrency(concurrency: Union[str, int, fl Inverse of get_target_concurrency_from_per_worker_value """ - return int( - round( - parse_cpu_request(str(concurrency)) - * MAX_CONCURRENCY_TO_TARGET_CONCURRENCY_RATIO - / 1000.0 - ) - ) + return int(math.ceil(parse_cpu_request(str(concurrency)) / 1000.0)) def format_bytes(num_bytes) -> str: diff --git a/server/llm_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py similarity index 71% rename from server/llm_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py index be5f5537..f1c8b4f9 100644 --- a/server/llm_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_async_model_endpoint_inference_gateway.py @@ -1,15 +1,16 @@ import json +from datetime import datetime -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, ) -from llm_engine_server.domain.gateways.async_model_endpoint_inference_gateway import ( +from model_engine_server.domain.gateways.async_model_endpoint_inference_gateway import ( AsyncModelEndpointInferenceGateway, ) -from llm_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway +from model_engine_server.domain.gateways.task_queue_gateway import TaskQueueGateway class LiveAsyncModelEndpointInferenceGateway(AsyncModelEndpointInferenceGateway): @@ -30,14 +31,14 @@ def create_task( *, task_name: str = DEFAULT_CELERY_TASK_NAME, ) -> CreateAsyncTaskV1Response: - # Use json.loads instead of predict_request.dict() because we have overridden the '__root__' - # key in some fields, and __root__ overriding only reflects in the json() output. + # Use json.loads instead of predict_request.dict() because we have overridden the 'root' + # key in some fields, and root overriding only reflects in the json() output. predict_args = json.loads(predict_request.json()) send_task_response = self.task_queue_gateway.send_task( task_name=task_name, queue_name=topic, - args=[predict_args, predict_request.return_pickled], + args=[predict_args, datetime.now(), predict_request.return_pickled], expires=task_timeout_seconds, ) return CreateAsyncTaskV1Response(task_id=send_task_response.task_id) diff --git a/server/llm_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py b/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py similarity index 77% rename from server/llm_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py index 09eade35..93b87ce8 100644 --- a/server/llm_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_batch_job_orchestration_gateway.py @@ -1,26 +1,32 @@ from typing import Dict from kubernetes_asyncio.client.rest import ApiException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import BatchJobSerializationFormat -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways import BatchJobOrchestrationGateway -from llm_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.env_vars import GIT_TAG +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways import BatchJobOrchestrationGateway +from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, ) -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( get_kubernetes_batch_client, load_k8s_yaml, maybe_load_kube_config, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ( +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( BatchJobOrchestrationJobArguments, ) SHUTDOWN_GRACE_PERIOD = 60 -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class LiveBatchJobOrchestrationGateway(BatchJobOrchestrationGateway): @@ -53,6 +59,8 @@ async def create_batch_job_orchestrator( BATCH_JOB_TIMEOUT=timeout_seconds, BATCH_JOB_MAX_RUNTIME=int(timeout_seconds + SHUTDOWN_GRACE_PERIOD), BATCH_JOB_TTL_SECONDS_AFTER_FINISHED=BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, + GIT_TAG=GIT_TAG, + REQUEST_ID=LoggerTagManager.get(LoggerTagKey.REQUEST_ID) or "", ) resource_key = "batch-job-orchestration-job.yaml" deployment_spec = load_k8s_yaml(resource_key, substitution_kwargs) diff --git a/server/llm_engine_server/infra/gateways/live_batch_job_progress_gateway.py b/model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py similarity index 67% rename from server/llm_engine_server/infra/gateways/live_batch_job_progress_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py index 1db10e0d..ef7c506e 100644 --- a/server/llm_engine_server/infra/gateways/live_batch_job_progress_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_batch_job_progress_gateway.py @@ -1,13 +1,14 @@ -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import BatchJobProgress -from llm_engine_server.infra.gateways import BatchJobProgressGateway, FilesystemGateway +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import BatchJobProgress +from model_engine_server.infra.gateways import BatchJobProgressGateway +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) def get_batch_job_progress_location(user_id: str, batch_job_id: str): - return f"s3://{ml_infra_config().s3_bucket}/batch_job_progress/{user_id}/{batch_job_id}" + return f"s3://{infra_config().s3_bucket}/batch_job_progress/{user_id}/{batch_job_id}" class LiveBatchJobProgressGateway(BatchJobProgressGateway): @@ -21,7 +22,7 @@ def get_progress(self, owner: str, batch_job_id: str) -> BatchJobProgress: ) try: with self.filesystem_gateway.open( - progress_location, aws_profile=ml_infra_config().profile_ml_worker + progress_location, aws_profile=infra_config().profile_ml_worker ) as f: progress = BatchJobProgress.parse_raw(f.read()) except Exception: @@ -39,6 +40,6 @@ def update_progress(self, owner: str, batch_job_id: str, progress: BatchJobProgr user_id=owner, batch_job_id=batch_job_id ) with self.filesystem_gateway.open( - progress_location, "w", aws_profile=ml_infra_config().profile_ml_worker + progress_location, "w", aws_profile=infra_config().profile_ml_worker ) as f: f.write(progress.json()) diff --git a/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py new file mode 100644 index 00000000..2970bead --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/live_cron_job_gateway.py @@ -0,0 +1,177 @@ +from typing import Any, Dict, List, Optional + +from kubernetes_asyncio.client.rest import ApiException +from model_engine_server.common import dict_not_none +from model_engine_server.common.config import hmi_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways.cron_job_gateway import CronJobGateway +from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( + LAUNCH_JOB_ID_LABEL_SELECTOR, + _parse_job_status_from_k8s_obj, + make_job_id_to_pods_mapping, +) +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( + get_kubernetes_batch_client, + get_kubernetes_core_client, + load_k8s_yaml, + maybe_load_kube_config, +) +from model_engine_server.infra.gateways.resources.k8s_resource_types import CronTriggerArguments + +BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS = 10 + +logger = make_logger(logger_name()) + + +def _k8s_cron_job_name_from_id(trigger_id: str): + trigger_id_suffix = trigger_id[5:] # suffix following "trig_" contains xid + return f"launch-trigger-{trigger_id_suffix}" + + +class LiveCronJobGateway(CronJobGateway): + def __init__(self): + pass + + async def create_cronjob( + self, + *, + request_host: str, + trigger_id: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Dict[str, str], + ) -> None: + await maybe_load_kube_config() + + batch_client = get_kubernetes_batch_client() + + cron_job_name = _k8s_cron_job_name_from_id(trigger_id) + + cron_trigger_key = "cron-trigger.yaml" + substitution_kwargs = CronTriggerArguments( + HOST=request_host, + NAME=cron_job_name, + CREATED_BY=created_by, + OWNER=owner, + TEAM=default_job_metadata["team"], + PRODUCT=default_job_metadata["product"], + TRIGGER_ID=trigger_id, + CRON_SCHEDULE=cron_schedule, + DOCKER_IMAGE_BATCH_JOB_BUNDLE_ID=docker_image_batch_job_bundle_id, + JOB_CONFIG=self._format_dict_template_args(default_job_config or {}), + JOB_METADATA=self._format_dict_template_args(default_job_metadata), + BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS=BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS, + ) + cron_job_body = load_k8s_yaml(cron_trigger_key, substitution_kwargs) + + try: + await batch_client.create_namespaced_cron_job( + namespace=hmi_config.endpoint_namespace, body=cron_job_body + ) + except ApiException as exc: + logger.exception( + f"Exception encountered when creating batch cron job for docker image batch job bundle id '{docker_image_batch_job_bundle_id}' for {owner}" + ) + raise EndpointResourceInfraException from exc + + async def list_jobs( + self, + *, + owner: str, + trigger_id: Optional[str], + ) -> List[DockerImageBatchJob]: + await maybe_load_kube_config() + + batch_client = get_kubernetes_batch_client() + + try: + label_selector = f"trigger_id={trigger_id}" if trigger_id else f"owner={owner}" + jobs = await batch_client.list_namespaced_job( + namespace=hmi_config.endpoint_namespace, + label_selector=label_selector, + ) + except ApiException as exc: + logger.exception("Got an exception when trying to list the Jobs") + raise EndpointResourceInfraException from exc + + core_client = get_kubernetes_core_client() + + try: + label_selector = f"trigger_id={trigger_id}" if trigger_id else f"owner={owner},job-name" + pods = await core_client.list_namespaced_pod( + namespace=hmi_config.endpoint_namespace, + label_selector=label_selector, + ) + except ApiException as exc: + logger.exception("Got an exception when trying to list the Pods") + raise EndpointResourceInfraException from exc + + pods_per_job = make_job_id_to_pods_mapping(pods.items) + + return [ + DockerImageBatchJob( + id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR), + created_by=job.metadata.labels.get("created_by"), + owner=job.metadata.labels.get("owner"), + created_at=job.metadata.creation_timestamp, + completed_at=job.status.completion_time, + status=_parse_job_status_from_k8s_obj( + job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)] + ), + ) + for job in jobs.items + ] + + async def update_cronjob( + self, + *, + trigger_id: str, + cron_schedule: Optional[str], + suspend: Optional[bool], + ) -> None: + await maybe_load_kube_config() + + batch_client = get_kubernetes_batch_client() + + cron_job_name = _k8s_cron_job_name_from_id(trigger_id) + partial_body = dict(spec=dict_not_none(schedule=cron_schedule, suspend=suspend)) + + try: + await batch_client.patch_namespaced_cron_job( + name=cron_job_name, + namespace=hmi_config.endpoint_namespace, + body=partial_body, + ) + except ApiException: + logger.exception( + f"Exception encountered when patching batch cron job for trigger id '{trigger_id}', requested object likely does not exist" + ) + + async def delete_cronjob( + self, + *, + trigger_id: str, + ) -> None: + await maybe_load_kube_config() + + batch_client = get_kubernetes_batch_client() + + cron_job_name = _k8s_cron_job_name_from_id(trigger_id) + + try: + await batch_client.delete_namespaced_cron_job( + name=cron_job_name, namespace=hmi_config.endpoint_namespace + ) + except ApiException: + logger.exception( + f"Exception encountered when deleting batch cron job for trigger id '{trigger_id}', requested object likely does not exist" + ) + + @staticmethod + def _format_dict_template_args(obj: Dict[str, Any]) -> str: + return f"{obj}".replace("'", '"') diff --git a/server/llm_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py similarity index 63% rename from server/llm_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py index 2c2f7a88..c3a86326 100644 --- a/server/llm_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_docker_image_batch_job_gateway.py @@ -1,38 +1,45 @@ import os import re +from collections import defaultdict from pathlib import Path from typing import Any, Dict, List, Optional, Tuple, TypedDict, Union from kubernetes_asyncio.client.models.v1_job import V1Job +from kubernetes_asyncio.client.models.v1_pod import V1Pod from kubernetes_asyncio.client.rest import ApiException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from llm_engine_server.common.serialization_utils import python_json_to_b64 -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities.batch_job_entity import BatchJobStatus, DockerImageBatchJob -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.domain.gateways.docker_image_batch_job_gateway import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.serialization_utils import python_json_to_b64 +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import ( + LoggerTagKey, + LoggerTagManager, + logger_name, + make_logger, +) +from model_engine_server.domain.entities.batch_job_entity import BatchJobStatus, DockerImageBatchJob +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways.docker_image_batch_job_gateway import ( DockerImageBatchJobGateway, ) -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( get_kubernetes_batch_client, + get_kubernetes_core_client, load_k8s_yaml, maybe_load_kube_config, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ( +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( DictStrStr, DockerImageBatchJobCpuArguments, DockerImageBatchJobGpuArguments, ) from xid import XID -DEFAULT_MOUNT_LOCATION = "/restricted_llm_engine/batch_payload.json" +DEFAULT_MOUNT_LOCATION = "/restricted_launch/batch_payload.json" # Must match resources/docker...{cpu,gpu}.yaml's label selector -LLM_ENGINE_JOB_ID_LABEL_SELECTOR = "llm_engine_job_id" +LAUNCH_JOB_ID_LABEL_SELECTOR = "launch_job_id" OWNER_LABEL_SELECTOR = "owner" - ENV: str = os.environ.get("DD_ENV") # type: ignore GIT_TAG: str = os.environ.get("GIT_TAG") # type: ignore SERVICE_CONFIG_PATH: str = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH") # type: ignore @@ -44,10 +51,10 @@ Path(__file__).parent.absolute() / "resources/docker_image_batch_job_gpu.yaml" ) -BATCH_JOB_MAX_RUNTIME_SECONDS = 43200 # 12 hours +BATCH_JOB_MAX_RUNTIME_SECONDS = 86400 * 7 # 7 days BATCH_JOB_TTL_SECONDS_AFTER_FINISHED = 86400 * 3 # 3 days -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class K8sEnvDict(TypedDict): @@ -56,7 +63,7 @@ class K8sEnvDict(TypedDict): def _get_job_id(): - return f"job-{XID().string()}" + return f"ft-{XID().string()}" def _check_batch_job_id_valid(job_id: str): @@ -82,7 +89,41 @@ def _add_list_values( def _k8s_job_name_from_id(job_id: str): # "di" stands for "docker image" btw - return f"llm-engine-di-batch-job-{job_id}" + return f"launch-di-batch-job-{job_id}" + + +def _parse_job_status_from_k8s_obj(job: V1Job, pods: List[V1Pod]) -> BatchJobStatus: + status = job.status + # these counts are the number of pods in some given status + if status.failed is not None and status.failed > 0: + return BatchJobStatus.FAILURE + if status.succeeded is not None and status.succeeded > 0: + return BatchJobStatus.SUCCESS + if status.ready is not None and status.ready > 0: + return BatchJobStatus.RUNNING # empirically this doesn't happen + if status.active is not None and status.active > 0: + for pod in pods: + # In case there are multiple pods for a given job (e.g. if a pod gets shut down) + # let's interpret the job as running if any of the pods are running + # I haven't empirically seen this, but guard against it just in case. + if pod.status.phase == "Running": + return BatchJobStatus.RUNNING + return BatchJobStatus.PENDING + return BatchJobStatus.PENDING + + +def make_job_id_to_pods_mapping(pods: List[V1Pod]) -> defaultdict: + """ + Returns a defaultdict mapping job IDs to pods + """ + job_id_to_pods_mapping = defaultdict(list) + for pod in pods: + job_id = pod.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR) + if job_id is not None: + job_id_to_pods_mapping[job_id].append(pod) + else: + logger.warning(f"Pod {pod.metadata.name} has no job ID label") + return job_id_to_pods_mapping class LiveDockerImageBatchJobGateway(DockerImageBatchJobGateway): @@ -102,6 +143,9 @@ async def create_docker_image_batch_job( resource_requests: CreateDockerImageBatchJobResourceRequests, labels: Dict[str, str], mount_location: Optional[str], + annotations: Optional[Dict[str, str]] = None, + override_job_max_runtime_s: Optional[int] = None, + num_workers: Optional[int] = 1, ) -> str: await maybe_load_kube_config() @@ -116,7 +160,11 @@ async def create_docker_image_batch_job( created_by=created_by, owner=owner, labels=labels, + annotations=annotations, + override_job_max_runtime_s=override_job_max_runtime_s, + num_workers=num_workers, ) + logger.info(resource_spec) batch_client = get_kubernetes_batch_client() @@ -144,10 +192,14 @@ def _generate_job_spec( created_by: str, owner: str, labels: Dict[str, str], + annotations: Optional[Dict[str, str]] = None, + override_job_max_runtime_s: Optional[int] = None, + num_workers: Optional[int] = 1, ) -> Tuple[str, Dict[str, Any]]: job_id = _get_job_id() job_name = _k8s_job_name_from_id(job_id) # why do we even have job_name and id job_config_b64encoded = python_json_to_b64(job_config) + job_runtime_limit = override_job_max_runtime_s or BATCH_JOB_MAX_RUNTIME_SECONDS storage = resource_requests.storage storage_dict = DictStrStr("") if storage is not None: @@ -169,23 +221,27 @@ def _generate_job_spec( CREATED_BY=created_by, OWNER=owner, JOB_ID=job_id, + GIT_TAG=GIT_TAG, # Batch Job Arguments - BATCH_JOB_MAX_RUNTIME=BATCH_JOB_MAX_RUNTIME_SECONDS, + BATCH_JOB_MAX_RUNTIME=job_runtime_limit, BATCH_JOB_TTL_SECONDS_AFTER_FINISHED=BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, - IMAGE=f"{ml_infra_config().docker_repo_prefix}/{repo}:{tag}", + IMAGE=f"{infra_config().docker_repo_prefix}/{repo}:{tag}", COMMAND=command, CPUS=str(resource_requests.cpus), MEMORY=str(resource_requests.memory), STORAGE_DICT=storage_dict, MOUNT_PATH=mount_path, - INPUT_LOCATION="--input-local", # TODO when we enable mounting remote s3files should be "--input-remote" + INPUT_LOCATION="--input-local", + # TODO when we enable mounting remote s3files should be "--input-remote" S3_FILE="unused", LOCAL_FILE_NAME=mount_location, FILE_CONTENTS_B64ENCODED=job_config_b64encoded, - AWS_ROLE=ml_infra_config().profile_ml_inference_worker, + AWS_ROLE=infra_config().profile_ml_inference_worker, # GPU Arguments GPU_TYPE=resource_requests.gpu_type.value, GPUS=resource_requests.gpus or 1, + REQUEST_ID=LoggerTagManager.get(LoggerTagKey.REQUEST_ID) or "", + BATCH_JOB_NUM_WORKERS=num_workers or 1, ) else: resource_key = "docker-image-batch-job-cpu.yaml" @@ -198,20 +254,24 @@ def _generate_job_spec( CREATED_BY=created_by, OWNER=owner, JOB_ID=job_id, + GIT_TAG=GIT_TAG, # Batch Job Arguments - BATCH_JOB_MAX_RUNTIME=BATCH_JOB_MAX_RUNTIME_SECONDS, + BATCH_JOB_MAX_RUNTIME=job_runtime_limit, BATCH_JOB_TTL_SECONDS_AFTER_FINISHED=BATCH_JOB_TTL_SECONDS_AFTER_FINISHED, - IMAGE=f"{ml_infra_config().docker_repo_prefix}/{repo}:{tag}", + IMAGE=f"{infra_config().docker_repo_prefix}/{repo}:{tag}", COMMAND=command, CPUS=str(resource_requests.cpus), MEMORY=str(resource_requests.memory), STORAGE_DICT=storage_dict, MOUNT_PATH=mount_path, - INPUT_LOCATION="--input-local", # TODO when we enable mounting remote s3files should be "--input-remote" + INPUT_LOCATION="--input-local", + # TODO when we enable mounting remote s3files should be "--input-remote" S3_FILE="unused", LOCAL_FILE_NAME=mount_location, FILE_CONTENTS_B64ENCODED=job_config_b64encoded, - AWS_ROLE=ml_infra_config().profile_ml_inference_worker, + AWS_ROLE=infra_config().profile_ml_inference_worker, + REQUEST_ID=LoggerTagManager.get(LoggerTagKey.REQUEST_ID) or "", + BATCH_JOB_NUM_WORKERS=num_workers or 1, ) resource_spec = load_k8s_yaml(resource_key, substitution_kwargs) @@ -227,6 +287,13 @@ def _generate_job_spec( resource_spec["spec"]["template"]["spec"]["containers"][0]["env"] = _add_list_values( container_env_list, override_envs ) + if "annotations" in resource_spec["metadata"]: + resource_spec["metadata"]["annotations"].update(annotations) + else: + resource_spec["metadata"]["annotations"] = annotations + # add trigger_id label if job was spawned by trigger + if "trigger_id" in labels: + resource_spec["metadata"]["labels"]["trigger_id"] = labels["trigger_id"] return job_id, resource_spec async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[DockerImageBatchJob]: @@ -239,7 +306,7 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker try: jobs = await batch_client.list_namespaced_job( namespace=hmi_config.endpoint_namespace, - label_selector=f"{LLM_ENGINE_JOB_ID_LABEL_SELECTOR}={batch_job_id}", + label_selector=f"{LAUNCH_JOB_ID_LABEL_SELECTOR}={batch_job_id}", ) if len(jobs.items) == 0: logger.info(f"Job id {batch_job_id} not found") @@ -251,9 +318,21 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker logger.exception("Got an exception when trying to read the Job") raise EndpointResourceInfraException from exc + core_client = get_kubernetes_core_client() + try: + pods = await core_client.list_namespaced_pod( + namespace=hmi_config.endpoint_namespace, + label_selector=f"{LAUNCH_JOB_ID_LABEL_SELECTOR}={batch_job_id}", + ) + except ApiException as exc: + logger.exception("Got an exception when trying to read pods for the Job") + raise EndpointResourceInfraException from exc + # This pod list isn't always needed, but it's simpler code-wise to always make the request + job_labels = job.metadata.labels + annotations = job.metadata.annotations - status = self._parse_job_status_from_k8s_obj(job) + status = _parse_job_status_from_k8s_obj(job, pods.items) return DockerImageBatchJob( id=batch_job_id, @@ -262,6 +341,8 @@ async def get_docker_image_batch_job(self, batch_job_id: str) -> Optional[Docker created_at=job.metadata.creation_timestamp, completed_at=job.status.completion_time, status=status, + annotations=annotations, + num_workers=job.spec.completions, ) async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatchJob]: @@ -276,14 +357,31 @@ async def list_docker_image_batch_jobs(self, owner: str) -> List[DockerImageBatc logger.exception("Got an exception when trying to list the Jobs") raise EndpointResourceInfraException from exc + core_client = get_kubernetes_core_client() + try: + pods = await core_client.list_namespaced_pod( + namespace=hmi_config.endpoint_namespace, + label_selector=f"{OWNER_LABEL_SELECTOR}={owner},job-name", # get only pods associated with a job + ) + except ApiException as exc: + logger.exception("Got an exception when trying to read pods for the Job") + raise EndpointResourceInfraException from exc + + # Join jobs + pods + pods_per_job = make_job_id_to_pods_mapping(pods.items) + return [ DockerImageBatchJob( - id=job.metadata.labels.get(LLM_ENGINE_JOB_ID_LABEL_SELECTOR), + id=job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR), created_by=job.metadata.labels.get("created_by"), owner=owner, created_at=job.metadata.creation_timestamp, completed_at=job.status.completion_time, - status=self._parse_job_status_from_k8s_obj(job), + annotations=job.metadata.annotations, + status=_parse_job_status_from_k8s_obj( + job, pods_per_job[job.metadata.labels.get(LAUNCH_JOB_ID_LABEL_SELECTOR)] + ), + num_workers=job.spec.completions, ) for job in jobs.items ] @@ -321,17 +419,3 @@ async def _delete_docker_image_batch_job(self, batch_job_id: str) -> bool: ) raise EndpointResourceInfraException from exc return True - - @staticmethod - def _parse_job_status_from_k8s_obj(job: V1Job) -> BatchJobStatus: - status = job.status - # these counts are the number of pods in some given status - if status.failed is not None and status.failed > 0: - return BatchJobStatus.FAILURE - if status.succeeded is not None and status.succeeded > 0: - return BatchJobStatus.SUCCESS - if status.ready is not None and status.ready > 0: - return BatchJobStatus.RUNNING # empirically this doesn't happen - if status.active is not None and status.active > 0: - return BatchJobStatus.RUNNING # TODO this might be a mix of pending and running - return BatchJobStatus.PENDING diff --git a/server/llm_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py similarity index 88% rename from server/llm_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py index d4e282ad..7294e533 100644 --- a/server/llm_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoint_infra_gateway.py @@ -1,13 +1,13 @@ import os from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from llm_engine_server.common.settings import ( +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.common.settings import ( RESTRICTED_ENDPOINT_LABELS, generate_deployment_name, get_service_builder_queue, ) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -15,14 +15,16 @@ ModelEndpointRecord, StorageSpecificationType, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.domain.gateways import TaskQueueGateway -from llm_engine_server.infra.gateways.model_endpoint_infra_gateway import ModelEndpointInfraGateway -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways import TaskQueueGateway +from model_engine_server.infra.gateways.model_endpoint_infra_gateway import ( + ModelEndpointInfraGateway, +) +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -BUILD_TASK_NAME = "llm_engine_server.service_builder.tasks_v1.build_endpoint" +BUILD_TASK_NAME = "model_engine_server.service_builder.tasks_v1.build_endpoint" SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") @@ -59,7 +61,8 @@ def create_model_endpoint_infra( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, aws_role: str, results_s3_bucket: str, @@ -68,6 +71,7 @@ def create_model_endpoint_infra( labels: Dict[str, str], prewarm: bool, high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], ) -> str: @@ -85,6 +89,7 @@ def create_model_endpoint_infra( memory=memory, gpu_type=gpu_type, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, aws_role=aws_role, results_s3_bucket=results_s3_bucket, @@ -93,6 +98,7 @@ def create_model_endpoint_infra( labels=labels, prewarm=prewarm, high_priority=high_priority, + billing_tags=billing_tags, default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, ) @@ -122,6 +128,7 @@ async def update_model_endpoint_infra( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, ) -> str: @@ -146,6 +153,8 @@ async def update_model_endpoint_infra( gpu_type = infra_state.resource_state.gpu_type if storage is None: storage = infra_state.resource_state.storage + # Don't allow changing nodes_per_worker + nodes_per_worker = infra_state.resource_state.nodes_per_worker if optimize_costs is None: optimize_costs = infra_state.resource_state.optimize_costs or False if child_fn_info is None: @@ -159,6 +168,8 @@ async def update_model_endpoint_infra( infra_state.labels.update(labels) labels = infra_state.labels assert labels is not None + if billing_tags is None and endpoint_config is not None: + billing_tags = endpoint_config.billing_tags redact_restricted_labels(labels) if prewarm is None: if infra_state.prewarm is None: @@ -192,6 +203,7 @@ async def update_model_endpoint_infra( memory=memory, gpu_type=gpu_type, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, aws_role=aws_role, results_s3_bucket=results_s3_bucket, @@ -200,6 +212,7 @@ async def update_model_endpoint_infra( labels=labels, prewarm=prewarm, high_priority=high_priority, + billing_tags=billing_tags, default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, ) diff --git a/server/llm_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py similarity index 75% rename from server/llm_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py rename to model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py index 6ebff349..883335bf 100644 --- a/server/llm_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/live_model_endpoints_schema_gateway.py @@ -1,11 +1,13 @@ import json from enum import Enum -from typing import Any, Callable, Dict, Sequence, Set, Type, Union +from typing import Any, Callable, Dict, List, Sequence, Set, Type, Union +import pydantic from fastapi import routing -from fastapi.openapi.utils import get_openapi_path -from fastapi.utils import get_model_definitions -from llm_engine_server.common.dtos.tasks import ( +from fastapi._compat import GenerateJsonSchema, get_definitions +from fastapi.openapi.constants import REF_TEMPLATE +from fastapi.openapi.utils import get_fields_from_routes, get_openapi_path +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, GetAsyncTaskV1Response, RequestSchema, @@ -13,8 +15,8 @@ SyncEndpointPredictV1Response, TaskStatus, ) -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.entities import ( +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities import ( CallbackAuth, CallbackBasicAuth, CallbackmTLSAuth, @@ -22,16 +24,14 @@ ModelEndpointsSchema, ModelEndpointType, ) -from llm_engine_server.domain.gateways import ModelEndpointsSchemaGateway -from pydantic import BaseModel +from model_engine_server.domain.gateways import ModelEndpointsSchemaGateway +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway from starlette.routing import BaseRoute -from . import FilesystemGateway - # Caches the default model definition so we don't need to recompute every time _default_model_definitions = None -API_REFERENCE_TITLE = "LLMEngine Endpoints API Reference" +API_REFERENCE_TITLE = "Launch Endpoints API Reference" API_REFERENCE_VERSION = "1.0.0" @@ -39,9 +39,7 @@ def predict_stub_async(payload: EndpointPredictV1Request) -> GetAsyncTaskV1Respo raise NotImplementedError -def predict_stub_sync( - payload: EndpointPredictV1Request, -) -> SyncEndpointPredictV1Response: +def predict_stub_sync(payload: EndpointPredictV1Request) -> SyncEndpointPredictV1Response: raise NotImplementedError @@ -59,7 +57,7 @@ def get_model_endpoints_schema( model_endpoint_names = [] model_definitions = {} for record in model_endpoint_records: - response_model: Type[BaseModel] = GetAsyncTaskV1Response + response_model: Type[pydantic.BaseModel] = GetAsyncTaskV1Response predict_stub: Callable[[EndpointPredictV1Request], Any] = predict_stub_async base_route = "/v1/async-tasks" if record.endpoint_type == ModelEndpointType.SYNC: @@ -74,6 +72,7 @@ def get_model_endpoints_schema( methods=["POST"], ) routes.append(route) + definitions = self.get_schemas_from_model_endpoint_record(record) definitions = LiveModelEndpointsSchemaGateway.update_model_definitions_with_prefix( prefix=record.name, model_definitions=definitions @@ -122,10 +121,20 @@ def get_openapi( if isinstance(route, routing.APIRoute): prefix = model_endpoint_name model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix) + schema_generator = GenerateJsonSchema(ref_template=REF_TEMPLATE) + all_fields = get_fields_from_routes([route]) + field_mapping, _ = get_definitions( + fields=all_fields, + schema_generator=schema_generator, + model_name_map=model_name_map, + ) + result = get_openapi_path( route=route, - model_name_map=model_name_map, operation_ids=operation_ids, + schema_generator=schema_generator, + model_name_map=model_name_map, + field_mapping=field_mapping, ) if result: path, security_schemes, path_definitions = result @@ -155,19 +164,17 @@ def update_model_definitions_with_prefix( Returns: Dict[str, Any]: The updated model definitions. """ - models: Set[Union[Type[BaseModel], Type[Enum]]] = { - CallbackAuth, - CallbackBasicAuth, - CallbackmTLSAuth, - TaskStatus, + models: List[Type[pydantic.BaseModel]] = [ EndpointPredictV1Request, GetAsyncTaskV1Response, SyncEndpointPredictV1Response, - } - definitions = get_model_definitions( - flat_models=models, - model_name_map=LiveModelEndpointsSchemaGateway.get_model_name_map(prefix), + ] + + model_name_map = LiveModelEndpointsSchemaGateway.get_model_name_map(prefix) + definitions: Dict[str, Any] = LiveModelEndpointsSchemaGateway.get_model_definitions( + models=models, model_name_map=model_name_map ) + user_definitions = {} for k, v in model_definitions.items(): LiveModelEndpointsSchemaGateway.update_schema_refs_with_prefix(v, prefix) @@ -191,9 +198,7 @@ def update_schema_refs_with_prefix(schema: Dict[str, Any], prefix: str) -> None: LiveModelEndpointsSchemaGateway.update_schema_refs_with_prefix(item, prefix) @staticmethod - def get_model_name_map( - prefix: str, - ) -> Dict[Union[Type[BaseModel], Type[Enum]], str]: + def get_model_name_map(prefix: str) -> Dict[Union[Type[pydantic.BaseModel], Type[Enum]], str]: return { CallbackAuth: "CallbackAuth", CallbackBasicAuth: "CallbackBasicAuth", @@ -223,9 +228,7 @@ def get_schemas_from_model_endpoint_record( try: if schema_location is not None: with self.filesystem_gateway.open( - schema_location, - "rb", - aws_profile=ml_infra_config().profile_ml_worker, + schema_location, "rb", aws_profile=infra_config().profile_ml_worker ) as f: schema = json.load(f) finally: @@ -239,8 +242,8 @@ def get_default_model_definitions() -> Dict[str, Any]: global _default_model_definitions if _default_model_definitions is None: - _default_model_definitions = get_model_definitions( - flat_models={RequestSchema, ResponseSchema}, + _default_model_definitions = LiveModelEndpointsSchemaGateway.get_model_definitions( + models=[RequestSchema, ResponseSchema], model_name_map={ RequestSchema: "RequestSchema", ResponseSchema: "ResponseSchema", @@ -248,3 +251,21 @@ def get_default_model_definitions() -> Dict[str, Any]: ) return _default_model_definitions + + @staticmethod + def get_model_definitions( + models: Sequence[Type[pydantic.BaseModel]], + model_name_map: Dict[Union[Type[pydantic.BaseModel], Type[Enum]], str], + ) -> Dict[str, Any]: + """Get OpenAPI definitions for provided models using the name provided in model_name_map""" + + definitions = {} + for model in models: + schema = model.model_json_schema( + schema_generator=GenerateJsonSchema, ref_template=REF_TEMPLATE + ) + m_defs = schema.pop("$defs", {}) + definitions.update(m_defs) + model_name = model_name_map.get(model, model.__name__) + definitions[model_name] = schema + return definitions diff --git a/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py new file mode 100644 index 00000000..97f2b83a --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py @@ -0,0 +1,261 @@ +from typing import Any, AsyncIterable, Dict, Optional + +import aiohttp +import orjson +import requests +import sseclient +from model_engine_server.common.aiohttp_sse_client import EventSource +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, + SyncEndpointPredictV1Response, + TaskStatus, +) +from model_engine_server.common.env_vars import CIRCLECI, LOCAL +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( + InvalidRequestException, + NoHealthyUpstreamException, + TooManyRequestsException, + UpstreamServiceError, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway +from model_engine_server.domain.gateways.streaming_model_endpoint_inference_gateway import ( + StreamingModelEndpointInferenceGateway, +) +from model_engine_server.infra.gateways.dns_resolver import resolve_dns +from model_engine_server.infra.gateways.k8s_resource_parser import get_node_port +from orjson import JSONDecodeError +from tenacity import ( + AsyncRetrying, + RetryError, + retry_if_exception_type, + stop_after_attempt, + stop_after_delay, + stop_any, + wait_exponential, +) + +logger = make_logger(logger_name()) + +SYNC_ENDPOINT_RETRIES = 8 # Must be an integer >= 0 +SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 +SYNC_ENDPOINT_MAX_RETRY_WAIT = 5 +SYNC_ENDPOINT_EXP_BACKOFF_BASE = ( + 1.2 # Must be a float > 1.0, lower number means more retries but less time waiting. +) + + +def _get_streaming_endpoint_url( + service_name: str, path: str = "/stream", manually_resolve_dns: bool = False +) -> str: + if CIRCLECI: + # Circle CI: a NodePort is used to expose the service + # The IP address is obtained from `minikube ip`. + protocol: str = "http" + hostname: str = f"192.168.49.2:{get_node_port(service_name)}" + elif LOCAL: + # local development: the svc.cluster.local address is only available w/in the k8s cluster + protocol = "https" + hostname = f"{service_name}.{infra_config().dns_host_domain}" + elif manually_resolve_dns: + protocol = "http" + hostname = resolve_dns( + f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local", port=protocol + ) + else: + protocol = "http" + # no need to hit external DNS resolution if we're w/in the k8s cluster + hostname = f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" + return f"{protocol}://{hostname}{path}" + + +def _serialize_json(data) -> str: + # Use orjson, which is faster and more correct than native Python json library. + # This is more important for sync endpoints, which are more latency-sensitive. + return orjson.dumps(data).decode() + + +class LiveStreamingModelEndpointInferenceGateway(StreamingModelEndpointInferenceGateway): + """ + Concrete implementation for an StreamingModelEndpointInferenceGateway. + + make_single_request() makes the streaming inference request to the endpoint + make_request_with_retries() wraps make_single_request() with retries + streaming_predict() wraps make_request_with_retries() and yields SyncEndpointPredictV1Response + """ + + def __init__(self, monitoring_metrics_gateway: MonitoringMetricsGateway, use_asyncio: bool): + self.monitoring_metrics_gateway = monitoring_metrics_gateway + self.use_asyncio = use_asyncio + + async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): + errored = False + if self.use_asyncio: + async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: + aio_resp = await aioclient.post( + request_url, + json=payload_json, + headers={"Content-Type": "application/json"}, + ) + status = aio_resp.status + if status == 200: + async with EventSource(response=aio_resp) as event_source: + async for event in event_source: + yield event.data + else: + content = await aio_resp.read() + errored = True + else: + resp = requests.post( + request_url, + json=payload_json, + headers={"Content-Type": "application/json"}, + stream=True, + ) + client = sseclient.SSEClient(resp) + status = resp.status_code + if status == 200: + for event in client.events(): + yield event.data + else: + content = resp.content + errored = True + + # Need to have these exceptions raised outside the async context so that + # tenacity can properly capture them. + if errored: + if status == 429: + raise TooManyRequestsException("429 returned") + if status == 503: + raise NoHealthyUpstreamException("503 returned") + else: + raise UpstreamServiceError(status_code=status, content=content) + + async def make_request_with_retries( + self, + request_url: str, + payload_json: Dict[str, Any], + timeout_seconds: float, + num_retries: int, + endpoint_name: str, + ) -> AsyncIterable[Dict[str, Any]]: + # Copied from document-endpoint + # More details at https://tenacity.readthedocs.io/en/latest/#retrying-code-block + # Try/catch + for loop makes us retry only when we get a 429 from the synchronous endpoint. + # We should be creating a new requests Session each time, which should avoid sending + # requests to the same endpoint. This is admittedly a hack until we get proper + # least-outstanding-requests load balancing to our http endpoints + + try: + async for attempt in AsyncRetrying( + stop=stop_any( + stop_after_attempt(num_retries + 1), + stop_after_delay(timeout_seconds), + ), + retry=retry_if_exception_type( + ( + TooManyRequestsException, + NoHealthyUpstreamException, + aiohttp.ClientConnectorError, + ) + ), + wait=wait_exponential( + multiplier=1, + min=1, + max=SYNC_ENDPOINT_MAX_RETRY_WAIT, + exp_base=SYNC_ENDPOINT_EXP_BACKOFF_BASE, + ), + ): + with attempt: + if attempt.retry_state.attempt_number > 1: + logger.info( + f"Retry number {attempt.retry_state.attempt_number}" + ) # pragma: no cover + response = self.make_single_request(request_url, payload_json) + async for item in response: + yield orjson.loads(item) + return + except RetryError as e: + if isinstance(e.last_attempt.exception(), TooManyRequestsException): + logger.warning("Hit max # of retries, returning 429 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 429) + raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") + elif isinstance(e.last_attempt.exception(), NoHealthyUpstreamException): + logger.warning("Pods didn't spin up in time, returning 503 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 503) + raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") + elif isinstance(e.last_attempt.exception(), aiohttp.ClientConnectorError): + logger.warning("ClientConnectorError, returning 503 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 503) + raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") + else: + logger.error("Unknown Exception Type") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 500) + raise UpstreamServiceError(status_code=500, content=b"Unknown error") + except JSONDecodeError: + logger.exception("JSONDecodeError") + raise UpstreamServiceError(status_code=500, content=b"JSONDecodeError") + + # Never reached because tenacity should throw a RetryError if we exit the for loop. + # This is for mypy. + # pragma: no cover + raise Exception("Should never reach this line") + + async def streaming_predict( + self, + topic: str, + predict_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool = False, + endpoint_name: Optional[str] = None, + ) -> AsyncIterable[SyncEndpointPredictV1Response]: + deployment_url = _get_streaming_endpoint_url( + topic, + path=predict_request.destination_path or "/stream", + manually_resolve_dns=manually_resolve_dns, + ) + + try: + timeout_seconds = ( + SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS + if predict_request.timeout_seconds is None + else predict_request.timeout_seconds + ) + num_retries = ( + SYNC_ENDPOINT_RETRIES + if predict_request.num_retries is None + else predict_request.num_retries + ) + response = self.make_request_with_retries( + request_url=deployment_url, + payload_json=predict_request.model_dump(exclude_none=True), + timeout_seconds=timeout_seconds, + num_retries=num_retries, + endpoint_name=endpoint_name or topic, + ) + async for item in response: + yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item) + except UpstreamServiceError as exc: + logger.error(f"Service error on streaming task: {exc.content!r}") + + if exc.status_code == 400: + error_json = orjson.loads(exc.content.decode("utf-8")) + if "result" in error_json: + error_json = orjson.loads(error_json["result"]) + raise InvalidRequestException(error_json) + + try: + error_json = orjson.loads(exc.content.decode("utf-8")) + result_traceback = ( + error_json.get("detail", {}).get("traceback") + if isinstance(error_json, dict) + else None + ) + except JSONDecodeError: + result_traceback = exc.content.decode() + + yield SyncEndpointPredictV1Response( + status=TaskStatus.FAILURE, + traceback=result_traceback, + ) diff --git a/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py new file mode 100644 index 00000000..3e083fb6 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py @@ -0,0 +1,249 @@ +from typing import Any, Dict, Optional + +import aiohttp +import orjson +import requests +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, + SyncEndpointPredictV1Response, + TaskStatus, +) +from model_engine_server.common.env_vars import CIRCLECI, LOCAL +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ( + InvalidRequestException, + NoHealthyUpstreamException, + TooManyRequestsException, + UpstreamServiceError, +) +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway +from model_engine_server.domain.gateways.sync_model_endpoint_inference_gateway import ( + SyncModelEndpointInferenceGateway, +) +from model_engine_server.infra.gateways.dns_resolver import resolve_dns +from model_engine_server.infra.gateways.k8s_resource_parser import get_node_port +from tenacity import ( + AsyncRetrying, + RetryError, + retry_if_exception_type, + stop_after_attempt, + stop_after_delay, + stop_any, + wait_exponential, +) + +logger = make_logger(logger_name()) + +SYNC_ENDPOINT_RETRIES = 8 # Must be an integer >= 0 +SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 +SYNC_ENDPOINT_MAX_RETRY_WAIT = 5 +SYNC_ENDPOINT_EXP_BACKOFF_BASE = ( + 1.2 # Must be a float > 1.0, lower number means more retries but less time waiting. +) + + +def _get_sync_endpoint_url( + service_name: str, destination_path: str = "/predict", manually_resolve_dns: bool = False +) -> str: + if CIRCLECI: + # Circle CI: a NodePort is used to expose the service + # The IP address is obtained from `minikube ip`. + protocol: str = "http" + hostname: str = f"192.168.49.2:{get_node_port(service_name)}" + elif LOCAL: + # local development: the svc.cluster.local address is only available w/in the k8s cluster + protocol = "https" + hostname = f"{service_name}.{infra_config().dns_host_domain}" + elif manually_resolve_dns: + protocol = "http" + hostname = resolve_dns( + f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local", port=protocol + ) + else: + protocol = "http" + # no need to hit external DNS resolution if we're w/in the k8s cluster + hostname = f"{service_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" + return f"{protocol}://{hostname}{destination_path}" + + +def _serialize_json(data) -> str: + # Use orjson, which is faster and more correct than native Python json library. + # This is more important for sync endpoints, which are more latency-sensitive. + return orjson.dumps(data).decode() + + +class LiveSyncModelEndpointInferenceGateway(SyncModelEndpointInferenceGateway): + """ + Concrete implementation for an SyncModelEndpointInferenceGateway. + """ + + def __init__(self, monitoring_metrics_gateway: MonitoringMetricsGateway, use_asyncio: bool): + self.monitoring_metrics_gateway = monitoring_metrics_gateway + self.use_asyncio = use_asyncio + + async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): + if self.use_asyncio: + async with aiohttp.ClientSession(json_serialize=_serialize_json) as client: + aio_resp = await client.post( + request_url, + json=payload_json, + headers={"Content-Type": "application/json"}, + ) + status = aio_resp.status + if status == 200: + return await aio_resp.json() + content = await aio_resp.read() + else: + resp = requests.post( + request_url, + json=payload_json, + headers={"Content-Type": "application/json"}, + ) + status = resp.status_code + if status == 200: + return resp.json() + content = resp.content + + # Need to have these exceptions raised outside the async context so that + # tenacity can properly capture them. + if status == 429: + raise TooManyRequestsException("429 returned") + if status == 503: + raise NoHealthyUpstreamException("503 returned") + else: + raise UpstreamServiceError(status_code=status, content=content) + + async def make_request_with_retries( + self, + request_url: str, + payload_json: Dict[str, Any], + timeout_seconds: float, + num_retries: int, + endpoint_name: str, + ) -> Dict[str, Any]: + # Copied from document-endpoint + # More details at https://tenacity.readthedocs.io/en/latest/#retrying-code-block + # Try/catch + for loop makes us retry only when we get a 429 from the synchronous endpoint. + # We should be creating a new requests Session each time, which should avoid sending + # requests to the same endpoint. This is admittedly a hack until we get proper + # least-outstanding-requests load balancing to our http endpoints + + try: + async for attempt in AsyncRetrying( + stop=stop_any( + stop_after_attempt(num_retries + 1), + stop_after_delay(timeout_seconds), + ), + retry=retry_if_exception_type( + ( + TooManyRequestsException, + NoHealthyUpstreamException, + aiohttp.ClientConnectorError, + ) + ), + wait=wait_exponential( + multiplier=1, + min=1, + max=SYNC_ENDPOINT_MAX_RETRY_WAIT, + exp_base=SYNC_ENDPOINT_EXP_BACKOFF_BASE, + ), + ): + with attempt: + if attempt.retry_state.attempt_number > 1: # pragma: no cover + logger.info(f"Retry number {attempt.retry_state.attempt_number}") + return await self.make_single_request(request_url, payload_json) + except RetryError as e: + if isinstance(e.last_attempt.exception(), TooManyRequestsException): + logger.warning("Hit max # of retries, returning 429 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 429) + raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") + elif isinstance(e.last_attempt.exception(), NoHealthyUpstreamException): + logger.warning("Pods didn't spin up in time, returning 503 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 503) + raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") + elif isinstance(e.last_attempt.exception(), aiohttp.ClientConnectorError): + logger.warning("ClientConnectorError, returning 503 to client") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 503) + raise UpstreamServiceError(status_code=503, content=b"No healthy upstream") + else: + logger.error("Unknown Exception Type") + self.monitoring_metrics_gateway.emit_http_call_error_metrics(endpoint_name, 500) + raise UpstreamServiceError(status_code=500, content=b"Unknown error") + + # Never reached because tenacity should throw a RetryError if we exit the for loop. + # This is for mypy. + # pragma: no cover + return {} + + async def predict( + self, + topic: str, + predict_request: SyncEndpointPredictV1Request, + manually_resolve_dns: bool = False, + endpoint_name: Optional[str] = None, + ) -> SyncEndpointPredictV1Response: + deployment_url = _get_sync_endpoint_url( + topic, + destination_path=predict_request.destination_path or "/predict", + manually_resolve_dns=manually_resolve_dns, + ) + + try: + timeout_seconds = ( + SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS + if predict_request.timeout_seconds is None + else predict_request.timeout_seconds + ) + num_retries = ( + SYNC_ENDPOINT_RETRIES + if predict_request.num_retries is None + else predict_request.num_retries + ) + response = await self.make_request_with_retries( + request_url=deployment_url, + payload_json=predict_request.model_dump(exclude_none=True), + timeout_seconds=timeout_seconds, + num_retries=num_retries, + endpoint_name=endpoint_name or topic, + ) + except UpstreamServiceError as exc: + logger.error(f"Service error on sync task: {exc.content!r}") + + if exc.status_code == 400: + error_json = orjson.loads(exc.content.decode("utf-8")) + if "result" in error_json: + error_json = orjson.loads(error_json["result"]) + + raise InvalidRequestException(error_json) + + try: + # Try to parse traceback from the response, fallback to just return all the content if failed. + # Three cases considered: + # detail.traceback + # result."detail.traceback" + # result."detail[]" + error_json = orjson.loads(exc.content.decode("utf-8")) + if "result" in error_json: + error_json = orjson.loads(error_json["result"]) + + detail = error_json.get("detail", {}) + if not isinstance(detail, dict): + result_traceback = orjson.dumps(error_json) + else: + result_traceback = error_json.get("detail", {}).get( + "traceback", "Failed to parse traceback" + ) + return SyncEndpointPredictV1Response( + status=TaskStatus.FAILURE, + traceback=result_traceback, + ) + + except Exception as e: + logger.error(f"Failed to parse error: {e}") + return SyncEndpointPredictV1Response( + status=TaskStatus.FAILURE, traceback=exc.content.decode() + ) + + return SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=response) diff --git a/server/llm_engine_server/infra/gateways/model_endpoint_infra_gateway.py b/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py similarity index 95% rename from server/llm_engine_server/infra/gateways/model_endpoint_infra_gateway.py rename to model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py index 51ee3c13..a61a890f 100644 --- a/server/llm_engine_server/infra/gateways/model_endpoint_infra_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/model_endpoint_infra_gateway.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from typing import Any, Dict, List, Optional -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -29,7 +29,8 @@ def create_model_endpoint_infra( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, aws_role: str, results_s3_bucket: str, @@ -38,6 +39,7 @@ def create_model_endpoint_infra( labels: Dict[str, str], prewarm: bool, high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], ) -> str: @@ -65,6 +67,7 @@ def create_model_endpoint_infra( to False high_priority: Makes all pods for this endpoint higher priority to enable faster pod spinup time. Higher priority pods will displace the lower priority dummy pods from shared pool. + billing_tags: Arbitrary tags passed to billing default_callback_url: The default callback URL to use for the model endpoint. default_callback_auth: The default callback auth to use for the model endpoint. @@ -91,6 +94,7 @@ async def update_model_endpoint_infra( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth], ) -> str: @@ -98,6 +102,7 @@ async def update_model_endpoint_infra( Updates the underlying infrastructure for a Model Endpoint. Args: + billing_tags: Arbitrary tags passed to billing model_endpoint_record: The associated record of a model endpoint. min_workers: The minimum number of workers for the model endpoint. max_workers: The maximum number of workers for the model endpoint. diff --git a/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py b/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py new file mode 100644 index 00000000..027442e8 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py @@ -0,0 +1,55 @@ +from typing import Optional + +import aioredis +from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import ( + InferenceAutoscalingMetricsGateway, +) + +EXPIRY_SECONDS = 60 # 1 minute; this gets added to the cooldown time present in the keda ScaledObject to get total +# scaledown time. This also needs to be larger than the keda ScaledObject's refresh rate. +PREWARM_EXPIRY_SECONDS = 60 * 60 # 1 hour + + +class RedisInferenceAutoscalingMetricsGateway(InferenceAutoscalingMetricsGateway): + def __init__( + self, redis_info: Optional[str] = None, redis_client: Optional[aioredis.Redis] = None + ): + assert redis_info or redis_client, "Either redis_info or redis_client must be defined." + if redis_info: + # If aioredis cannot create a connection pool, reraise that as an error because the + # default error message is cryptic and not obvious. + try: + self._redis = aioredis.from_url(redis_info, health_check_interval=60) + except Exception as exc: + raise RuntimeError( + "If redis_info is specified, RedisInferenceAutoscalingMetricsGateway must be" + "initialized within a coroutine. Please specify the redis_client directly." + ) from exc + else: + assert redis_client is not None # for mypy + self._redis = redis_client + + @staticmethod + def _find_redis_key(endpoint_id: str): + # Keep in line with keda scaled object yaml + return f"launch-endpoint-autoscaling:{endpoint_id}" + + async def _emit_metric(self, endpoint_id: str, expiry_time: int): + key = self._find_redis_key(endpoint_id) + await self._redis.expire(key, expiry_time) # does nothing if key doesn't exist, + # but avoids a race condition where the key expires in between the lpush and subsequent expire commands + await self._redis.lpush(key, 1) # we only care about the length of the list, not the values + await self._redis.ltrim(key, 0, 0) # we only want to scale from 0 to 1 for redis + await self._redis.expire(key, expiry_time) + + async def emit_inference_autoscaling_metric(self, endpoint_id: str): + await self._emit_metric(endpoint_id, EXPIRY_SECONDS) + + async def emit_prewarm_metric(self, endpoint_id: str): + await self._emit_metric(endpoint_id, PREWARM_EXPIRY_SECONDS) + + async def create_or_update_resources(self, endpoint_id: str): + pass # no extra resources needed + + async def delete_resources(self, endpoint_id: str): + pass diff --git a/model-engine/model_engine_server/infra/gateways/resources/__init__.py b/model-engine/model_engine_server/infra/gateways/resources/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py new file mode 100644 index 00000000..3799ed65 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py @@ -0,0 +1,68 @@ +import os +from typing import Any, Dict + +from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from azure.servicebus.management import ServiceBusAdministrationClient +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, +) + +logger = make_logger(logger_name()) + + +def _get_servicebus_administration_client() -> ServiceBusAdministrationClient: + return ServiceBusAdministrationClient( + f"{os.getenv('SERVICEBUS_NAMESPACE')}.servicebus.windows.net", + credential=DefaultAzureCredential(), + ) + + +class ASBQueueEndpointResourceDelegate(QueueEndpointResourceDelegate): + """ + Using Azure Service Bus. + """ + + async def create_queue_if_not_exists( + self, + endpoint_id: str, + endpoint_name: str, + endpoint_created_by: str, + endpoint_labels: Dict[str, Any], + ) -> QueueInfo: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + client.create_queue(queue_name=queue_name) + except ResourceExistsError: + pass + + return QueueInfo(queue_name, None) + + async def delete_queue(self, endpoint_id: str) -> None: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + client.delete_queue(queue_name=queue_name) + except ResourceNotFoundError: + logger.info(f"Could not find ASB queue {queue_name} for endpoint {endpoint_id}") + + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + with _get_servicebus_administration_client() as client: + try: + queue_attributes = client.get_queue_runtime_properties(queue_name=queue_name) + except ResourceNotFoundError as e: + raise EndpointResourceInfraException( + f"Could not find ASB queue {queue_name} for endpoint {endpoint_id}" + ) from e + + # queue_attributes does have other fields, but we don't need them right now + return { + "name": queue_attributes.name, + "total_message_count": queue_attributes.total_message_count, + "active_message_count": queue_attributes.active_message_count, + } diff --git a/server/llm_engine_server/infra/gateways/resources/endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py similarity index 91% rename from server/llm_engine_server/infra/gateways/resources/endpoint_resource_gateway.py rename to model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py index c09f1247..8c2779b3 100644 --- a/server/llm_engine_server/infra/gateways/resources/endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/endpoint_resource_gateway.py @@ -1,18 +1,17 @@ from abc import ABC, abstractmethod from typing import Dict, Generic, Sequence, Tuple, TypeVar -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.pydantic_types import BaseModel +from model_engine_server.domain.entities import ( ModelEndpointInfraState, ModelEndpointRecord, ModelEndpointType, ) -from pydantic import BaseModel +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import QueueInfo __all__: Sequence[str] = ( "EndpointResourceGateway", - "QueueInfo", "EndpointResourceGatewayCreateOrUpdateResourcesResponse", ) @@ -21,11 +20,6 @@ class EndpointResourceGatewayCreateOrUpdateResourcesResponse(BaseModel): destination: str -class QueueInfo(BaseModel): - queue_name: str - broker: BrokerType - - Q = TypeVar("Q", bound=QueueInfo) """Either a QueueInfo or some specialization of it. """ diff --git a/server/llm_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py similarity index 58% rename from server/llm_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py index 1c9ad4a5..9ded2d6e 100644 --- a/server/llm_engine_server/infra/gateways/resources/fake_sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/fake_queue_endpoint_resource_delegate.py @@ -1,32 +1,31 @@ from typing import Any, Dict, Sequence -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, - SQSQueueInfo, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, ) -from mypy_boto3_sqs.type_defs import GetQueueAttributesResultTypeDef -__all__: Sequence[str] = ("FakeSQSEndpointResourceDelegate",) +__all__: Sequence[str] = ("FakeQueueEndpointResourceDelegate",) -class FakeSQSEndpointResourceDelegate(SQSEndpointResourceDelegate): +class FakeQueueEndpointResourceDelegate(QueueEndpointResourceDelegate): async def create_queue_if_not_exists( self, endpoint_id: str, endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], - ) -> SQSQueueInfo: - queue_name = SQSEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + ) -> QueueInfo: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) queue_url = f"http://foobar.com/{queue_name}" - return SQSQueueInfo(queue_name, queue_url) + return QueueInfo(queue_name, queue_url) async def delete_queue(self, endpoint_id: str) -> None: # Don't need to do anything, since the contract says that deleting is a no-op, # and we don't need to simulate real exceptions. pass - async def get_queue_attributes(self, endpoint_id: str) -> GetQueueAttributesResultTypeDef: + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: return { "Attributes": { "ApproximateNumberOfMessages": "100", diff --git a/server/llm_engine_server/infra/gateways/resources/image_cache_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py similarity index 75% rename from server/llm_engine_server/infra/gateways/resources/image_cache_gateway.py rename to model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py index fcb3be26..84af84bd 100644 --- a/server/llm_engine_server/infra/gateways/resources/image_cache_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/image_cache_gateway.py @@ -1,16 +1,22 @@ -import hashlib import os from typing import Any, Dict, List, TypedDict, cast from kubernetes_asyncio.client.rest import ApiException -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( - get_kubernetes_apps_client, +from model_engine_server.common.config import hmi_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( + get_kubernetes_apps_client, # If you ever add more imports here, update test_image_cache_gateway accordingly, otherwise you will likely mangle live cluster resources +) +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( + k8s_yaml_exists, load_k8s_yaml, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ImageCacheArguments +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( + ImageCacheArguments, + compute_image_hash, +) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class CachedImages(TypedDict): @@ -18,10 +24,9 @@ class CachedImages(TypedDict): a10: List[str] a100: List[str] t4: List[str] - - -KUBERNETES_MAX_LENGTH = 64 -LLM_ENGINE_DEFAULT_NAMESPACE = "llm-engine" + h100: List[str] + h100_3g40gb: List[str] + h100_1g20gb: List[str] class ImageCacheGateway: @@ -35,27 +40,29 @@ async def create_or_update_image_cache(self, cached_images: CachedImages) -> Non base_path = os.getenv("WORKSPACE") if base_path is None: raise EnvironmentError("WORKSPACE env variable not found") - base_name = "llm-engine-image-cache" + base_name = "launch-image-cache" for compute_type, images in cached_images.items(): # Required for mypy TypedDict compute_type = cast(str, compute_type) + compute_type = compute_type.replace("_", "-") # for k8s valid name images = cast(list, images) name = f"{base_name}-{compute_type}" substitution_kwargs = ImageCacheArguments( RESOURCE_NAME=name, - NAMESPACE=LLM_ENGINE_DEFAULT_NAMESPACE, + NAMESPACE=hmi_config.endpoint_namespace, ) resource_key = f"image-cache-{compute_type}.yaml" + if not k8s_yaml_exists(resource_key): + logger.info(f"Didn't find yaml for {compute_type}, skipping") + continue image_cache = load_k8s_yaml(resource_key, substitution_kwargs) labels = image_cache["spec"]["template"]["metadata"]["labels"] containers = image_cache["spec"]["template"]["spec"]["containers"] for image in images: - image_hash = str(hashlib.md5(str(image).encode()).hexdigest())[ - :KUBERNETES_MAX_LENGTH - ] + image_hash = compute_image_hash(image) labels[image_hash] = "True" base_container_dict = { @@ -92,7 +99,7 @@ async def _create_image_cache( try: await apps_api.create_namespaced_daemon_set( - namespace=LLM_ENGINE_DEFAULT_NAMESPACE, + namespace=hmi_config.endpoint_namespace, body=image_cache, ) logger.info(f"Created image cache daemonset {name}") @@ -100,7 +107,7 @@ async def _create_image_cache( if exc.status == 409: # Do not update existing daemonset if the cache is unchanged existing_daemonsets = await apps_api.list_namespaced_daemon_set( - namespace=LLM_ENGINE_DEFAULT_NAMESPACE + namespace=hmi_config.endpoint_namespace ) for daemonset in existing_daemonsets.items: if daemonset.metadata.name == name: @@ -116,7 +123,7 @@ async def _create_image_cache( f"Image cache daemonset {name} already exists, replacing with new values" ) await apps_api.replace_namespaced_daemon_set( - name=name, namespace=LLM_ENGINE_DEFAULT_NAMESPACE, body=image_cache + name=name, namespace=hmi_config.endpoint_namespace, body=image_cache ) elif exc.status == 404: logger.exception("ImageCache API not found. Is the ImageCache CRD installed?") diff --git a/server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py similarity index 54% rename from server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py index 7ec3d19d..24eb4335 100644 --- a/server/llm_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_endpoint_resource_delegate.py @@ -15,34 +15,34 @@ ) from kubernetes_asyncio.client.rest import ApiException from kubernetes_asyncio.config import ConfigException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.common.env_vars import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.env_vars import ( CIRCLECI, - LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH, - LLM_ENGINE_SERVICE_TEMPLATE_FOLDER, + LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH, + LAUNCH_SERVICE_TEMPLATE_FOLDER, ) -from llm_engine_server.common.serialization_utils import b64_to_python_json, str_to_bool -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.common.serialization_utils import b64_to_python_json, str_to_bool +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import ( ModelEndpointConfig, ModelEndpointDeploymentState, ModelEndpointInfraState, + ModelEndpointRecord, ModelEndpointResourceState, ModelEndpointType, ModelEndpointUserConfigState, - RunnableImageFlavor, RunnableImageLike, - StreamingEnhancedRunnableImageFlavor, TritonEnhancedRunnableImageFlavor, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.k8s_resource_parser import ( +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.use_cases.model_endpoint_use_cases import MODEL_BUNDLE_CHANGED_KEY +from model_engine_server.infra.gateways.k8s_resource_parser import ( get_per_worker_value_from_target_concurrency, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ( - LLM_ENGINE_HIGH_PRIORITY_CLASS, +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( + LAUNCH_HIGH_PRIORITY_CLASS, CommonEndpointParams, HorizontalAutoscalingEndpointParams, ResourceArguments, @@ -50,9 +50,9 @@ get_endpoint_resource_arguments_from_request, ) from packaging import version -from pydantic.utils import deep_update +from pydantic.v1.utils import deep_update -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) HTTP_PORT = 5000 @@ -60,19 +60,25 @@ # and where the user actually owns the files BASE_PATH_IN_ENDPOINT = "/app" -DATADOG_ENV_VAR = { - "DATADOG_TRACE_ENABLED", - "DD_SERVICE", - "DD_ENV", - "DD_VERSION", - "DD_AGENT_HOST", +DATADOG_ENV_VAR = {"DD_TRACE_ENABLED", "DD_SERVICE", "DD_ENV", "DD_VERSION", "DD_AGENT_HOST"} +LWS_DEFAULT_ENV_VAR = { + "K8S_OWN_POD_NAME", + "K8S_OWN_NAMESPACE", + "K8S_LWS_NAME", + "K8S_LWS_CLUSTER_SIZE", } +# These two should match the values present in `service_template_config_map.yaml` +# for the container names in the LWS template. +LWS_LEADER_CONTAINER_NAME = "lws-leader" +LWS_WORKER_CONTAINER_NAME = "lws-worker" + _lazy_load_kubernetes_clients = True _kubernetes_apps_api = None _kubernetes_core_api = None _kubernetes_autoscaling_api = None _kubernetes_batch_api = None +_kubernetes_policy_api = None _kubernetes_custom_objects_api = None _kubernetes_cluster_version = None @@ -152,6 +158,16 @@ def get_kubernetes_batch_client(): # pragma: no cover return _kubernetes_batch_api +def get_kubernetes_policy_client(): # pragma: no cover + if _lazy_load_kubernetes_clients: + global _kubernetes_policy_api + else: + _kubernetes_policy_api = None + if not _kubernetes_policy_api: + _kubernetes_policy_api = kubernetes_asyncio.client.PolicyV1Api() + return _kubernetes_policy_api + + def get_kubernetes_custom_objects_client(): # pragma: no cover if _lazy_load_kubernetes_clients: global _kubernetes_custom_objects_api @@ -163,11 +179,11 @@ def get_kubernetes_custom_objects_client(): # pragma: no cover def _endpoint_id_to_k8s_resource_group_name(endpoint_id: str) -> str: - return f"llm-engine-endpoint-id-{endpoint_id}".replace("_", "-") + return f"launch-endpoint-id-{endpoint_id}".replace("_", "-") def _k8s_resource_group_name_to_endpoint_id(k8s_resource_group_name: str) -> str: - return k8s_resource_group_name.replace("llm-engine-endpoint-id-", "").replace("-", "_") + return k8s_resource_group_name.replace("launch-endpoint-id-", "").replace("-", "_") _kube_config_loaded = False @@ -189,12 +205,21 @@ async def maybe_load_kube_config(): _kube_config_loaded = True +def k8s_yaml_exists(key: str) -> bool: + if LAUNCH_SERVICE_TEMPLATE_FOLDER is not None: + return os.path.exists(os.path.join(LAUNCH_SERVICE_TEMPLATE_FOLDER, key)) + else: + with open(LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH, "r") as f: + config_map_str = yaml.safe_load(f.read()) + return key in config_map_str["data"] + + def load_k8s_yaml(key: str, substitution_kwargs: ResourceArguments) -> Dict[str, Any]: - if LLM_ENGINE_SERVICE_TEMPLATE_FOLDER is not None: - with open(os.path.join(LLM_ENGINE_SERVICE_TEMPLATE_FOLDER, key), "r") as f: + if LAUNCH_SERVICE_TEMPLATE_FOLDER is not None: + with open(os.path.join(LAUNCH_SERVICE_TEMPLATE_FOLDER, key), "r") as f: template_str = f.read() else: - with open(LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH, "r") as f: + with open(LAUNCH_SERVICE_TEMPLATE_CONFIG_MAP_PATH, "r") as f: config_map_str = yaml.safe_load(f.read()) template_str = config_map_str["data"][key] @@ -220,8 +245,39 @@ def get_main_container_from_deployment_template(deployment_template: Dict[str, A return user_container -def add_datadog_env_to_main_container(deployment_template: Dict[str, Any]) -> None: - user_container = get_main_container_from_deployment_template(deployment_template) +def get_leader_container_from_lws_template(lws_template: Dict[str, Any]): + containers = lws_template["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "containers" + ] + for container in containers: + if container["name"] == LWS_LEADER_CONTAINER_NAME: + leader_container = container + break + else: + raise ValueError( + f"leader container (container['name'] == '{LWS_LEADER_CONTAINER_NAME}') not found in lws template when adding datadog env to leader container." + ) + return leader_container + + +def get_worker_container_from_lws_template(lws_template: Dict[str, Any]): + containers = lws_template["spec"]["leaderWorkerTemplate"]["workerTemplate"]["spec"][ + "containers" + ] + for container in containers: + if container["name"] == LWS_WORKER_CONTAINER_NAME: + worker_container = container + break + else: + raise ValueError( + f"worker container (container['name'] == '{LWS_WORKER_CONTAINER_NAME}') not found in lws template when adding datadog env to worker container." + ) + return worker_container + + +def add_datadog_env_to_container( + deployment_template: Dict[str, Any], user_container: Dict[str, Any] +) -> None: user_container_envs = [] for env in user_container["env"]: @@ -231,7 +287,7 @@ def add_datadog_env_to_main_container(deployment_template: Dict[str, Any]) -> No user_container_envs.extend( [ { - "name": "DATADOG_TRACE_ENABLED", + "name": "DD_TRACE_ENABLED", "value": "false" if CIRCLECI else "true", }, { @@ -256,16 +312,51 @@ def add_datadog_env_to_main_container(deployment_template: Dict[str, Any]) -> No user_container["env"] = user_container_envs +def add_lws_default_env_vars_to_container(container: Dict[str, Any]) -> None: + container_envs = [] + container_envs.extend( + [ + {"name": "K8S_OWN_POD_NAME", "valueFrom": {"fieldRef": {"fieldPath": "metadata.name"}}}, + { + "name": "K8S_OWN_NAMESPACE", + "valueFrom": {"fieldRef": {"fieldPath": "metadata.namespace"}}, + }, + { + "name": "K8S_LWS_NAME", + "valueFrom": { + "fieldRef": {"fieldPath": "metadata.labels['leaderworkerset.sigs.k8s.io/name']"} + }, + }, + { + "name": "K8S_LWS_CLUSTER_SIZE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.annotations['leaderworkerset.sigs.k8s.io/size']" + } + }, + }, + ] + ) + + for env in container["env"]: + if env["name"] not in LWS_DEFAULT_ENV_VAR: + container_envs.append(env) + container["env"] = container_envs + + class K8SEndpointResourceDelegate: async def create_or_update_resources( self, request: CreateOrUpdateResourcesRequest, sqs_queue_name: Optional[str] = None, sqs_queue_url: Optional[str] = None, - ) -> None: + ) -> str: + """ + Returns a "destination", i.e. the name of the service/sqs queue to send tasks to the endpoint + """ await maybe_load_kube_config() try: - await self._create_or_update_resources( + return await self._create_or_update_resources( request=request, sqs_queue_name=sqs_queue_name, sqs_queue_url=sqs_queue_url, @@ -324,6 +415,18 @@ def _get_env_value_from_envlist( return envvar.value return None + @staticmethod + def _get_env_value_from_envlist_for_custom_object( + envlist: Optional[List[Dict]], name: str + ): # pragma: no cover + # Custom objects client returns nested Dicts, not objects. + if envlist is None: + return None + for envvar in envlist: + if envvar["name"] == name: + return envvar["value"] + return None + def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> CommonEndpointParams: """ Reads some values from k8s common to both sync and async endpoints @@ -334,7 +437,7 @@ def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> Common Dictionary with detected values """ main_container = self._get_main_container(deployment_config) - llm_engine_container = self._get_llm_engine_container(deployment_config) + launch_container = self._get_launch_container(deployment_config) resources = main_container.resources image = main_container.image @@ -343,7 +446,7 @@ def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> Common gpus = int((resources.limits or dict()).get("nvidia.com/gpu", 0)) storage = resources.requests.get("ephemeral-storage") - envlist = llm_engine_container.env + envlist = launch_container.env # Hack: for LIRA since the bundle_url isn't really a real env var # we use the `image` for now. This may change if we allow for unpickling # in LIRA. @@ -354,9 +457,9 @@ def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> Common # Temporary fix: new LIRA endpoints created should have these env vars # but old ones don't, so we can fetch them from the config. if aws_role is None: - aws_role = ml_infra_config().profile_ml_inference_worker + aws_role = infra_config().profile_ml_inference_worker if results_s3_bucket is None: - results_s3_bucket = ml_infra_config().s3_bucket + results_s3_bucket = infra_config().s3_bucket if bundle_url is None or aws_role is None or results_s3_bucket is None: raise ValueError("Failed to fetch common endpoint values.") @@ -386,6 +489,67 @@ def _get_common_endpoint_params(self, deployment_config: V1Deployment) -> Common ) return common_build_endpoint_request + def _get_common_endpoint_params_for_lws_type(self, lws_config: Any) -> CommonEndpointParams: + main_container = self._get_main_leader_container_from_lws(lws_config) + launch_container = self._get_launch_container_from_lws(lws_config) + + resources = main_container["resources"] + image = main_container["image"] + + cpus = resources["requests"]["cpu"] + memory = resources["requests"]["memory"] + gpus = int((resources["limits"] or dict()).get("nvidia.com/gpu", 0)) + storage = resources["requests"].get("ephemeral-storage") + + envlist = launch_container["env"] + # There really isn't a bundle_url for LWS since those use RunnableImages + bundle_url = ( + self._get_env_value_from_envlist_for_custom_object(envlist, "BUNDLE_URL") or image + ) + aws_role = self._get_env_value_from_envlist_for_custom_object(envlist, "AWS_PROFILE") + results_s3_bucket = self._get_env_value_from_envlist_for_custom_object( + envlist, "RESULTS_S3_BUCKET" + ) + + # AWS_PROFILE and RESULTS_S3_BUCKET should always be set, but if not present + # we can fetch them from the config. + if aws_role is None: + aws_role = infra_config().profile_ml_inference_worker + if results_s3_bucket is None: + results_s3_bucket = infra_config().s3_bucket + + if bundle_url is None or aws_role is None or results_s3_bucket is None: + raise ValueError("Failed to fetch common endpoint values.") + + try: + node_selector = lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "nodeSelector" + ] + gpu_type = node_selector.get("k8s.amazonaws.com/accelerator", None) + except KeyError: + gpu_type = None + + try: + labels = lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["metadata"][ + "labels" + ] + except KeyError: + labels = None + + common_build_endpoint_request: CommonEndpointParams = dict( + cpus=cpus, + memory=memory, + gpus=gpus, + gpu_type=gpu_type, + storage=storage, + bundle_url=bundle_url, + aws_role=aws_role, + results_s3_bucket=results_s3_bucket, + image=image, + labels=labels, + ) + return common_build_endpoint_request + @staticmethod def _get_main_container(deployment_config: V1Deployment) -> V1Container: pod_containers = deployment_config.spec.template.spec.containers @@ -395,7 +559,7 @@ def _get_main_container(deployment_config: V1Deployment) -> V1Container: return name_to_container["main"] @staticmethod - def _get_llm_engine_container(deployment_config: V1Deployment) -> V1Container: + def _get_launch_container(deployment_config: V1Deployment) -> V1Container: pod_containers = deployment_config.spec.template.spec.containers name_to_container = {container.name: container for container in pod_containers} @@ -412,10 +576,87 @@ def _get_llm_engine_container(deployment_config: V1Deployment) -> V1Container: raise ValueError("No main container detected") return name_to_container["main"] + @staticmethod + def _get_main_leader_container_from_lws(lws_config: Any): + """ + Similar to _get_main_container, this returns a nested dict. + """ + leader_containers = lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "containers" + ] + name_to_container = {container["name"]: container for container in leader_containers} + if LWS_LEADER_CONTAINER_NAME not in name_to_container: + raise ValueError("No main leader container detected") + return name_to_container[LWS_LEADER_CONTAINER_NAME] + + @staticmethod + def _get_launch_container_from_lws(lws_config: Any): + leader_containers = lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "containers" + ] + name_to_container = {container["name"]: container for container in leader_containers} + # If a celery forwarder is present, use that + if "celery-forwarder" in name_to_container: + return name_to_container["celery-forwarder"] + + # If a http forwarder is present, use that + if "http-forwarder" in name_to_container: + return name_to_container["http-forwarder"] + + # Don't need backwards compatibility here + raise ValueError("No forwarder container detected") + # --- Private low level fns that interact with k8s @staticmethod - async def _create_deployment(deployment: Dict[str, Any], name: str) -> None: + async def _create_lws( + lws: Dict[str, Any], + name: str, + ) -> None: + """ + Lower-level function to create/replace a LWS + Args: + lws: LWS body (a nested Dict in format specified by Kubernetes) + name: The name of the LWS on k8s + Returns: + Nothing: raises k8s APIException if failure + """ + custom_objects_api = get_kubernetes_custom_objects_client() + try: + await custom_objects_api.create_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + body=lws, + ) + except ApiException as exc: + if exc.status == 409: + logger.info(f"LeaderWorkerSet {name} already exists, replacing") + existing_lws = await custom_objects_api.get_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + name=name, + ) + new_lws = deep_update(existing_lws, lws) + await custom_objects_api.replace_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + name=name, + body=new_lws, + ) + else: + logger.exception("Got an exception when trying to apply the LeaderWorkerSet") + raise + + @staticmethod + async def _create_deployment( + model_endpoint_record: ModelEndpointRecord, deployment: Dict[str, Any], name: str + ) -> None: """ Lower-level function to create/patch a k8s deployment Args: @@ -436,32 +677,49 @@ async def _create_deployment(deployment: Dict[str, Any], name: str) -> None: ) except ApiException as exc: if exc.status == 409: - logger.info(f"Deployment {name} already exists, patching") - - if "replicas" in deployment["spec"]: - # Don't pass in replicas if we're doing an update, because we want to just - # let the autoscaler do its thing. - del deployment["spec"]["replicas"] - - logger.info(f"Deployment {name} contents: {deployment}") - - try: - await apps_client.patch_namespaced_deployment( + if ( + model_endpoint_record.metadata is not None + and MODEL_BUNDLE_CHANGED_KEY in model_endpoint_record.metadata + ): + logger.info( + f"Deployment {name} already exists, replacing since model bundle has changed" + ) + logger.info(f"Deployment {name} contents: {deployment}") + await apps_client.replace_namespaced_deployment( name=name, namespace=hmi_config.endpoint_namespace, body=deployment, ) - except ApiException as exc2: - if exc2.status in [409, 422]: - logger.info(f"Deployment {name} failed to patch, falling back to replacing") - await apps_client.replace_namespaced_deployment( + else: + logger.info(f"Deployment {name} already exists, patching") + + if "replicas" in deployment["spec"]: + # Don't pass in replicas if we're doing an update, because we want to just + # let the autoscaler do its thing. + del deployment["spec"]["replicas"] + logger.info(f"Deployment {name} contents: {deployment}") + + try: + await apps_client.patch_namespaced_deployment( name=name, namespace=hmi_config.endpoint_namespace, body=deployment, ) - else: - logger.exception("Got an exception when trying to patch the Deployment") - raise + except ApiException as exc2: + if exc2.status in [409, 422]: + logger.info( + f"Deployment {name} failed to patch, falling back to replacing" + ) + await apps_client.replace_namespaced_deployment( + name=name, + namespace=hmi_config.endpoint_namespace, + body=deployment, + ) + else: + logger.exception( + "Got an exception when trying to replace the Deployment" + ) + raise else: logger.exception("Got an exception when trying to apply the Deployment") raise @@ -585,6 +843,85 @@ async def _create_vpa(vpa: Dict[str, Any], name: str) -> None: logger.exception("Got an exception when trying to apply the VerticalPodAutoscaler") raise + @staticmethod + async def _create_pdb(pdb: Dict[str, Any], name: str) -> None: + """ + Lower-level function to create/patch a k8s PodDisruptionBudget (pdb) + Args: + pdb: PDB body (a nested Dict in the format specified by Kubernetes) + name: The name of the pdb on K8s + + Returns: + Nothing; raises a k8s ApiException if failure + + """ + policy_api = get_kubernetes_policy_client() + try: + await policy_api.create_namespaced_pod_disruption_budget( + namespace=hmi_config.endpoint_namespace, + body=pdb, + ) + except ApiException as exc: + if exc.status == 409: + logger.info(f"PodDisruptionBudget {name} already exists, replacing") + + existing_pdb = await policy_api.read_namespaced_pod_disruption_budget( + name=name, namespace=hmi_config.endpoint_namespace + ) + replace_pdb = pdb.copy() + if "metadata" not in replace_pdb: + replace_pdb["metadata"] = {} + replace_pdb["metadata"]["resourceVersion"] = existing_pdb.metadata.resource_version + + await policy_api.replace_namespaced_pod_disruption_budget( + name=name, + namespace=hmi_config.endpoint_namespace, + body=replace_pdb, + ) + else: + logger.exception("Got an exception when trying to apply the PodDisruptionBudget") + raise + + @staticmethod + async def _create_keda_scaled_object(scaled_object: Dict[str, Any], name: str) -> None: + custom_objects_api = get_kubernetes_custom_objects_client() + try: + await custom_objects_api.create_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + body=scaled_object, + ) + except ApiException as exc: + if exc.status == 409: + logger.info(f"ScaledObject {name} already exists, replacing") + + # The async k8s client has a bug with patching custom objects, so we manually + # merge the new ScaledObject with the old one and then replace the old one with the merged + # one. See _create_vpa for more details. + # There is a setting `restoreToOriginalReplicaCount` in the keda ScaledObject that should be set to + # false which should make it safe to do this replace (as opposed to a patch) + existing_scaled_object = await custom_objects_api.get_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + name=name, + ) + new_scaled_object = deep_update(existing_scaled_object, scaled_object) + await custom_objects_api.replace_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + name=name, + body=new_scaled_object, + ) + else: + logger.exception("Got an exception when trying to apply the ScaledObject") + raise + @staticmethod async def _create_destination_rule(destination_rule: Dict[str, Any], name: str) -> None: """ @@ -679,6 +1016,46 @@ async def _create_virtual_service(virtual_service: Dict[str, Any], name: str) -> logger.exception("Got an exception when trying to apply the VirtualService") raise + @staticmethod + async def _create_lws_service_entry(lws_service_entry: Dict[str, Any], name: str) -> None: + # Note: this istio ServiceEntry is specific to the LWS case, + # as it is used to enable the "hack" where we manually resolve + # the IP of a K8s service and route to the IP directly. + custom_objects_api = get_kubernetes_custom_objects_client() + try: + await custom_objects_api.create_namespaced_custom_object( + group="networking.istio.io", + version="v1beta1", + namespace=hmi_config.endpoint_namespace, + plural="serviceentries", + body=lws_service_entry, + ) + except ApiException as exc: + if exc.status == 409: + logger.info(f"ServiceEntry {name} already exists, replacing") + # The async k8s client has a bug with patching custom objects, so we manually + # merge the new ServiceEntry with the old one and then replace the old one with the merged + # one. + existing_service_entry = await custom_objects_api.get_namespaced_custom_object( + group="networking.istio.io", + version="v1beta1", + namespace=hmi_config.endpoint_namespace, + plural="serviceentries", + name=name, + ) + new_service_entry = deep_update(existing_service_entry, lws_service_entry) + await custom_objects_api.replace_namespaced_custom_object( + group="networking.istio.io", + version="v1beta1", + namespace=hmi_config.endpoint_namespace, + plural="serviceentries", + name=name, + body=new_service_entry, + ) + else: + logger.exception("Got an exception when trying to apply the ServiceEntry") + raise + @staticmethod async def _create_service(service, name: str) -> None: """ @@ -715,7 +1092,7 @@ async def _get_config_maps( ) -> List[kubernetes_asyncio.client.models.v1_config_map.V1ConfigMap]: """ Gets ConfigMaps associated with a given user id + endpoint name - This should be considered the same abstraction level as get_deployment + This should be considered the same abstraction level as _get_deployment """ k8s_core_api = get_kubernetes_core_client() @@ -739,9 +1116,37 @@ async def _get_config_maps( return config_maps.items @staticmethod - async def _get_all_config_maps() -> List[ - kubernetes_asyncio.client.models.v1_config_map.V1ConfigMap - ]: + async def _get_deployment(endpoint_id, deployment_name): + """ + Gets the Deployment associated with a given endpoint_id + deployment name + Handles a legacy fallback case as well, where Deployments were named differently. + + """ + apps_client = get_kubernetes_apps_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + deployment_config = await apps_client.read_namespaced_deployment( + name=k8s_resource_group_name, namespace=hmi_config.endpoint_namespace + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Could not find resource, falling back to legacy deployment_name: " + f"{k8s_resource_group_name=}, {endpoint_id=}, {deployment_name=}" + ) + k8s_resource_group_name = deployment_name + deployment_config = await apps_client.read_namespaced_deployment( + name=k8s_resource_group_name, + namespace=hmi_config.endpoint_namespace, + ) + else: + raise + return deployment_config + + @staticmethod + async def _get_all_config_maps() -> ( + List[kubernetes_asyncio.client.models.v1_config_map.V1ConfigMap] + ): k8s_core_api = get_kubernetes_core_client() config_maps = await k8s_core_api.list_namespaced_config_map( namespace=hmi_config.endpoint_namespace @@ -785,6 +1190,28 @@ def _translate_k8s_config_maps_to_user_config_data( endpoint_config=endpoint_config, ) + @staticmethod + async def _delete_lws(endpoint_id: str) -> bool: + custom_objects_client = get_kubernetes_custom_objects_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + await custom_objects_client.delete_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + name=k8s_resource_group_name, + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Trying to delete nonexistent LeaderWorkerSet {k8s_resource_group_name}" + ) + else: + logger.exception(f"Deletion of LeaderWorkerSet {k8s_resource_group_name} failed") + return False + return True + @staticmethod async def _delete_deployment(endpoint_id: str, deployment_name: str) -> bool: apps_client = get_kubernetes_apps_client() @@ -846,8 +1273,8 @@ async def _delete_config_maps(self, endpoint_id: str, deployment_name: str) -> b @staticmethod async def _delete_service(endpoint_id: str, deployment_name: str) -> bool: - core_client = get_kubernetes_core_client() k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + core_client = get_kubernetes_core_client() try: await core_client.delete_namespaced_service( name=k8s_resource_group_name, namespace=hmi_config.endpoint_namespace @@ -877,6 +1304,22 @@ async def _delete_service(endpoint_id: str, deployment_name: str) -> bool: return False return True + @staticmethod + async def _delete_lws_service(endpoint_id: str, deployment_name: str): + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + lws_service_name = K8SEndpointResourceDelegate._get_lws_service_resource_name( + k8s_resource_group_name + ) + core_client = get_kubernetes_core_client() + try: + await core_client.delete_namespaced_service( + name=lws_service_name, namespace=hmi_config.endpoint_namespace + ) + except ApiException: + logger.exception(f"Deletion of Service {lws_service_name} failed") + return False + return True + @staticmethod async def _delete_destination_rule(endpoint_id: str) -> bool: custom_objects_client = get_kubernetes_custom_objects_client() @@ -922,13 +1365,35 @@ async def _delete_virtual_service(endpoint_id: str) -> bool: return True @staticmethod - async def _delete_vpa(endpoint_id: str) -> bool: + async def _delete_lws_service_entry(endpoint_id: str) -> bool: custom_objects_client = get_kubernetes_custom_objects_client() k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) try: await custom_objects_client.delete_namespaced_custom_object( - group="autoscaling.k8s.io", - version="v1", + group="networking.istio.io", + version="v1beta1", + namespace=hmi_config.endpoint_namespace, + plural="serviceentries", + name=k8s_resource_group_name, + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Trying to delete nonexistent ServiceEntry {k8s_resource_group_name}" + ) + else: + logger.exception(f"Deletion of ServiceEntry {k8s_resource_group_name} failed") + return False + return True + + @staticmethod + async def _delete_vpa(endpoint_id: str) -> bool: + custom_objects_client = get_kubernetes_custom_objects_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + await custom_objects_client.delete_namespaced_custom_object( + group="autoscaling.k8s.io", + version="v1", namespace=hmi_config.endpoint_namespace, plural="verticalpodautoscalers", name=k8s_resource_group_name, @@ -981,6 +1446,49 @@ async def _delete_hpa(endpoint_id: str, deployment_name: str) -> bool: return False return True + @staticmethod + async def _delete_pdb(endpoint_id: str) -> bool: + policy_client = get_kubernetes_policy_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + await policy_client.delete_namespaced_pod_disruption_budget( + namespace=hmi_config.endpoint_namespace, + name=k8s_resource_group_name, + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Trying to delete nonexistent PodDisruptionBudget {k8s_resource_group_name}" + ) + else: + logger.exception( + f"Deletion of PodDisruptionBudget {k8s_resource_group_name} failed" + ) + return False + return True + + @staticmethod + async def _delete_keda_scaled_object(endpoint_id: str) -> bool: + custom_objects_client = get_kubernetes_custom_objects_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + try: + await custom_objects_client.delete_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + name=k8s_resource_group_name, + ) + except ApiException as e: + if e.status == 404: + logger.warning( + f"Trying to delete nonexistent ScaledObject {k8s_resource_group_name}" + ) + else: + logger.exception(f"Deletion of ScaledObject {k8s_resource_group_name} failed") + return False + return True + # --- Private higher level fns that interact with k8s @staticmethod @@ -989,12 +1497,10 @@ def _get_deployment_resource_name(request: CreateOrUpdateResourcesRequest) -> st model_endpoint_record = build_endpoint_request.model_endpoint_record flavor = model_endpoint_record.current_model_bundle.flavor - if isinstance(flavor, (RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor)): - flavor_class = "runnable-image" - elif isinstance(flavor, TritonEnhancedRunnableImageFlavor): + if isinstance(flavor, TritonEnhancedRunnableImageFlavor): flavor_class = "triton-enhanced-runnable-image" else: - flavor_class = "artifact" + flavor_class = "runnable-image" mode = model_endpoint_record.endpoint_type.value device = "gpu" if build_endpoint_request.gpus > 0 else "cpu" @@ -1002,12 +1508,45 @@ def _get_deployment_resource_name(request: CreateOrUpdateResourcesRequest) -> st deployment_resource_name = f"deployment-{flavor_class}-{mode}-{device}" return deployment_resource_name + @staticmethod + def _get_lws_resource_name(request: CreateOrUpdateResourcesRequest) -> str: + build_endpoint_request = request.build_endpoint_request + model_endpoint_record = build_endpoint_request.model_endpoint_record + flavor = model_endpoint_record.current_model_bundle.flavor + if isinstance(flavor, TritonEnhancedRunnableImageFlavor): + flavor_class = "triton-enhanced-runnable-image" + else: + flavor_class = "runnable-image" + if flavor_class == "triton-enhanced-runnable-image": + raise ValueError("LWS is not supported for Triton Enhanced Runnable Image") + # flavor not being triton-enhanced should already be checked in the endpoint create on the gateway + # but check again just in case + # Gateway should also guard against cloudpickle or zip being passed in here + + mode = model_endpoint_record.endpoint_type.value + device = "gpu" if build_endpoint_request.gpus > 0 else "cpu" + if mode not in ["streaming"]: + raise ValueError("LWS is not supported for async or sync endpoints") + if device not in ["gpu"]: + raise ValueError("LWS is not supported for CPU endpoints") + + lws_resource_name = f"leader-worker-set-{mode}-{device}" + return lws_resource_name + + @staticmethod + def _get_lws_service_resource_name(k8s_resource_group_name: str): + return f"{k8s_resource_group_name}-leader" + async def _create_or_update_resources( self, request: CreateOrUpdateResourcesRequest, sqs_queue_name: Optional[str] = None, sqs_queue_url: Optional[str] = None, - ) -> None: + ) -> str: + """ + Returns a "destination", which is how to address the endpoint, either through + sqs or through a k8s service. + """ sqs_queue_name_str = sqs_queue_name or "" sqs_queue_url_str = sqs_queue_url or "" build_endpoint_request = request.build_endpoint_request @@ -1015,28 +1554,56 @@ async def _create_or_update_resources( k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name( build_endpoint_request.model_endpoint_record.id ) + is_multinode = build_endpoint_request.nodes_per_worker > 1 - deployment_resource_name = self._get_deployment_resource_name(request) - deployment_arguments = get_endpoint_resource_arguments_from_request( - k8s_resource_group_name=k8s_resource_group_name, - request=request, - sqs_queue_name=sqs_queue_name_str, - sqs_queue_url=sqs_queue_url_str, - endpoint_resource_name=deployment_resource_name, - ) - deployment_template = load_k8s_yaml( - f"{deployment_resource_name}.yaml", deployment_arguments - ) - if isinstance( - request.build_endpoint_request.model_endpoint_record.current_model_bundle.flavor, - RunnableImageLike, - ): - add_datadog_env_to_main_container(deployment_template) - await self._create_deployment( - deployment=deployment_template, - name=k8s_resource_group_name, - ) + # Create LWS/Deployment + if is_multinode: + lws_resource_name = self._get_lws_resource_name(request) + lws_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name=lws_resource_name, + ) + lws_template = load_k8s_yaml(f"{lws_resource_name}.yaml", lws_arguments) + leader_template = get_leader_container_from_lws_template(lws_template) + worker_template = get_worker_container_from_lws_template(lws_template) + add_lws_default_env_vars_to_container(leader_template) + add_lws_default_env_vars_to_container(worker_template) + add_datadog_env_to_container(lws_template, leader_template) + add_datadog_env_to_container(lws_template, worker_template) + await self._create_lws( + lws=lws_template, + name=k8s_resource_group_name, + ) + k8s_service_name = self._get_lws_service_resource_name(k8s_resource_group_name) + else: + deployment_resource_name = self._get_deployment_resource_name(request) + deployment_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name=deployment_resource_name, + ) + deployment_template = load_k8s_yaml( + f"{deployment_resource_name}.yaml", deployment_arguments + ) + if isinstance( + request.build_endpoint_request.model_endpoint_record.current_model_bundle.flavor, + RunnableImageLike, + ): + user_container = get_main_container_from_deployment_template(deployment_template) + add_datadog_env_to_container(deployment_template, user_container) + await self._create_deployment( + model_endpoint_record=request.build_endpoint_request.model_endpoint_record, + deployment=deployment_template, + name=k8s_resource_group_name, + ) + k8s_service_name = k8s_resource_group_name + # Create ConfigMaps user_config_arguments = get_endpoint_resource_arguments_from_request( k8s_resource_group_name=k8s_resource_group_name, request=request, @@ -1063,6 +1630,7 @@ async def _create_or_update_resources( name=f"{k8s_resource_group_name}-endpoint-config", ) + # Create VPA if request.build_endpoint_request.optimize_costs: vpa_arguments = get_endpoint_resource_arguments_from_request( k8s_resource_group_name=k8s_resource_group_name, @@ -1077,10 +1645,33 @@ async def _create_or_update_resources( name=k8s_resource_group_name, ) - if model_endpoint_record.endpoint_type in { - ModelEndpointType.SYNC, - ModelEndpointType.STREAMING, - }: + # Create PDB + if not is_multinode: + # Only create PDB if we're not using LWS + pdb_config_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="pod-disruption-budget", + ) + pdb_template = load_k8s_yaml("pod-disruption-budget.yaml", pdb_config_arguments) + await self._create_pdb( + pdb=pdb_template, + name=k8s_resource_group_name, + ) + + # Create HPA/Keda Scaled Object, Service (one of two types), VirtualService, DestinationRule, ServiceEntry + # as needed + if ( + model_endpoint_record.endpoint_type + in { + ModelEndpointType.SYNC, + ModelEndpointType.STREAMING, + } + and not is_multinode + ): + # Don't need HPA, keda, istio resources for LWS or async endpoints cluster_version = get_kubernetes_cluster_version() # For k8s cluster versions 1.23 - 1.25 we need to use the v2beta2 api # For 1.26+ v2beta2 has been deperecated and merged into v2 @@ -1089,33 +1680,145 @@ async def _create_or_update_resources( else: api_version = "autoscaling/v2beta2" - hpa_arguments = get_endpoint_resource_arguments_from_request( + # create exactly one of HPA or KEDA ScaledObject, depending if we request more than 0 min_workers + # Right now, keda only will support scaling from 0 to 1 + # TODO support keda scaling from 1 to N as well + if request.build_endpoint_request.min_workers > 0: + # Delete keda scaled object if it exists so exactly one of HPA or KEDA ScaledObject remains + await self._delete_keda_scaled_object( + build_endpoint_request.model_endpoint_record.id + ) + hpa_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="horizontal-pod-autoscaler", + api_version=api_version, + ) + hpa_template = load_k8s_yaml("horizontal-pod-autoscaler.yaml", hpa_arguments) + await self._create_hpa( + hpa=hpa_template, + name=k8s_resource_group_name, + ) + else: # min workers == 0, use keda + # Delete hpa if it exists so exactly one of HPA or KEDA ScaledObject remains + await self._delete_hpa( + build_endpoint_request.model_endpoint_record.id, k8s_resource_group_name + ) + keda_scaled_object_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="keda-scaled-object", + ) + keda_scaled_object_template = load_k8s_yaml( + "keda-scaled-object.yaml", keda_scaled_object_arguments + ) + await self._create_keda_scaled_object( + scaled_object=keda_scaled_object_template, + name=k8s_resource_group_name, + ) + + service_arguments = get_endpoint_resource_arguments_from_request( k8s_resource_group_name=k8s_resource_group_name, request=request, sqs_queue_name=sqs_queue_name_str, sqs_queue_url=sqs_queue_url_str, - endpoint_resource_name="horizontal-pod-autoscaler", - api_version=api_version, + endpoint_resource_name="service", ) - hpa_template = load_k8s_yaml("horizontal-pod-autoscaler.yaml", hpa_arguments) - await self._create_hpa( - hpa=hpa_template, - name=k8s_resource_group_name, + service_template = load_k8s_yaml("service.yaml", service_arguments) + await self._create_service( + service=service_template, + name=k8s_service_name, ) + # TODO wsong: add flag to use istio and use these arguments + if hmi_config.istio_enabled: + virtual_service_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="virtual-service", + ) + virtual_service_template = load_k8s_yaml( + "virtual-service.yaml", virtual_service_arguments + ) + await self._create_virtual_service( + virtual_service=virtual_service_template, + name=k8s_resource_group_name, + ) + + destination_rule_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="destination-rule", + ) + destination_rule_template = load_k8s_yaml( + "destination-rule.yaml", destination_rule_arguments + ) + await self._create_destination_rule( + destination_rule=destination_rule_template, + name=k8s_resource_group_name, + ) + elif ( + model_endpoint_record.endpoint_type + in { + ModelEndpointType.SYNC, + ModelEndpointType.STREAMING, + } + and is_multinode + ): + # Only create the service (and serviceEntry if istio is enabled) service_arguments = get_endpoint_resource_arguments_from_request( k8s_resource_group_name=k8s_resource_group_name, request=request, sqs_queue_name=sqs_queue_name_str, sqs_queue_url=sqs_queue_url_str, - endpoint_resource_name="service", + service_name_override=k8s_service_name, + endpoint_resource_name="lws-service", ) - service_template = load_k8s_yaml("service.yaml", service_arguments) + service_template = load_k8s_yaml("lws-service.yaml", service_arguments) await self._create_service( service=service_template, - name=k8s_resource_group_name, + name=k8s_service_name, ) + if hmi_config.istio_enabled: + # If Istio is enabled, we also create a ServiceEntry. This is in service of the hack + # where we manually resolve the IP address of the K8s service created above. + # We empirically need to create this in order for the request to the service's IP address + # to go through. See live_{sync,streaming}_model_endpoint_inference_gateway.py for more details. + lws_service_entry_arguments = get_endpoint_resource_arguments_from_request( + k8s_resource_group_name=k8s_resource_group_name, + request=request, + sqs_queue_name=sqs_queue_name_str, + sqs_queue_url=sqs_queue_url_str, + endpoint_resource_name="lws-service-entry", + service_name_override=k8s_service_name, + ) + lws_service_entry_template = load_k8s_yaml( + "lws-service-entry.yaml", lws_service_entry_arguments + ) + await self._create_lws_service_entry( + lws_service_entry=lws_service_entry_template, + name=k8s_resource_group_name, + ) + if model_endpoint_record.endpoint_type in { + ModelEndpointType.SYNC, + ModelEndpointType.STREAMING, + }: + return k8s_service_name + elif model_endpoint_record.endpoint_type == ModelEndpointType.ASYNC: + return sqs_queue_name_str + else: + # We should never get here + raise ValueError(f"Unsupported endpoint type {model_endpoint_record.endpoint_type}") + @staticmethod def _get_vertical_autoscaling_params( vpa_config, @@ -1159,43 +1862,116 @@ def _get_sync_autoscaling_params( per_worker=per_worker, ) + @staticmethod + def _get_sync_autoscaling_params_from_keda( + keda_config, + ) -> HorizontalAutoscalingEndpointParams: + spec = keda_config["spec"] + concurrency = 1 + for trigger in spec["triggers"]: + if trigger["metadata"].get("metricName") == "request_concurrency_average": + # Needs to match what is defined in the keda-scaled-obj section in + # service_template_config_map.yaml! + concurrency = trigger["metadata"]["threshold"] + break + return dict( + max_workers=spec.get("maxReplicaCount"), + min_workers=spec.get("minReplicaCount"), + per_worker=concurrency, + ) + async def _get_resources( self, endpoint_id: str, deployment_name: str, endpoint_type: ModelEndpointType ) -> ModelEndpointInfraState: - apps_client = get_kubernetes_apps_client() + custom_objects_client = get_kubernetes_custom_objects_client() k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + + logger.info( + f"trying to find lws at {k8s_resource_group_name}, {hmi_config.endpoint_namespace}" + ) try: - deployment_config = await apps_client.read_namespaced_deployment( - name=k8s_resource_group_name, namespace=hmi_config.endpoint_namespace + lws_config = await custom_objects_client.get_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + name=k8s_resource_group_name, ) except ApiException as e: - if e.status == 404: - logger.warning( - f"Could not find resource, falling back to legacy deployment_name: " - f"{k8s_resource_group_name=}, {endpoint_id=}, {deployment_name=}" - ) - k8s_resource_group_name = deployment_name - deployment_config = await apps_client.read_namespaced_deployment( - name=k8s_resource_group_name, - namespace=hmi_config.endpoint_namespace, - ) - else: - raise + # Need to handle the case where lws CRD isn't installed as well as the lws not existing. + logger.info(e) + lws_config = None + + # Make the call here so we can use it in both places, also this makes _get_resources_from_lws_type make zero requests to k8s + config_maps = await self._get_config_maps( + endpoint_id=endpoint_id, deployment_name=k8s_resource_group_name + ) + + if lws_config is None: + infra_state = await self._get_resources_from_deployment_type( + endpoint_id=endpoint_id, + deployment_name=deployment_name, + endpoint_type=endpoint_type, + config_maps=config_maps, + ) + else: + infra_state = await self._get_resources_from_lws_type( + endpoint_id=endpoint_id, + deployment_name=deployment_name, + endpoint_type=endpoint_type, + lws_config=lws_config, + config_maps=config_maps, + ) + return infra_state + + async def _get_resources_from_deployment_type( + self, endpoint_id: str, deployment_name: str, endpoint_type: ModelEndpointType, config_maps + ) -> ModelEndpointInfraState: + custom_objects_client = get_kubernetes_custom_objects_client() + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + + deployment_config = await self._get_deployment(endpoint_id, deployment_name) common_params = self._get_common_endpoint_params(deployment_config) if endpoint_type == ModelEndpointType.ASYNC: horizontal_autoscaling_params = self._get_async_autoscaling_params(deployment_config) elif endpoint_type in {ModelEndpointType.SYNC, ModelEndpointType.STREAMING}: autoscaling_client = get_kubernetes_autoscaling_client() - hpa_config = await autoscaling_client.read_namespaced_horizontal_pod_autoscaler( - k8s_resource_group_name, hmi_config.endpoint_namespace - ) - horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config) + custom_object_client = get_kubernetes_custom_objects_client() + try: + hpa_config = await autoscaling_client.read_namespaced_horizontal_pod_autoscaler( + k8s_resource_group_name, hmi_config.endpoint_namespace + ) + except ApiException as e: + if e.status == 404: + hpa_config = None + else: + raise e + try: + keda_scaled_object_config = await custom_object_client.get_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + name=k8s_resource_group_name, + ) + except ApiException: + keda_scaled_object_config = None + if hpa_config is not None: + horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config) + elif keda_scaled_object_config is not None: + horizontal_autoscaling_params = self._get_sync_autoscaling_params_from_keda( + keda_scaled_object_config + ) + else: + raise EndpointResourceInfraException( + f"Could not find autoscaling config for {endpoint_type}" + ) else: raise ValueError(f"Unexpected endpoint type {endpoint_type}") vertical_autoscaling_params = None - custom_objects_client = get_kubernetes_custom_objects_client() + try: vpa_config = await custom_objects_client.get_namespaced_custom_object( group="autoscaling.k8s.io", @@ -1209,18 +1985,14 @@ async def _get_resources( if e.status == 404: pass - config_maps = await self._get_config_maps( - endpoint_id=endpoint_id, deployment_name=k8s_resource_group_name - ) - llm_engine_container = self._get_llm_engine_container(deployment_config) - envlist = llm_engine_container.env + launch_container = self._get_launch_container(deployment_config) + envlist = launch_container.env # Note: the env var PREWARM is either "true" or "false" string (or doesn't exist for legacy) # Convert this as early as possible to Optional[bool] to avoid bugs prewarm = str_to_bool(self._get_env_value_from_envlist(envlist, "PREWARM")) high_priority = ( - deployment_config.spec.template.spec.priority_class_name - == LLM_ENGINE_HIGH_PRIORITY_CLASS + deployment_config.spec.template.spec.priority_class_name == LAUNCH_HIGH_PRIORITY_CLASS ) infra_state = ModelEndpointInfraState( @@ -1244,6 +2016,7 @@ async def _get_resources( gpu_type=common_params["gpu_type"], # type: ignore memory=common_params["memory"], storage=common_params["storage"], + nodes_per_worker=1, # We're in "Deployment" case thus nodes_per_worker=1 optimize_costs=(vertical_autoscaling_params is not None), ), user_config_state=self._translate_k8s_config_maps_to_user_config_data( @@ -1252,6 +2025,63 @@ async def _get_resources( image=common_params["image"], num_queued_items=None, ) + + return infra_state + + async def _get_resources_from_lws_type( + self, + endpoint_id: str, + deployment_name: str, + endpoint_type: ModelEndpointType, + lws_config, + config_maps: List, + ) -> ModelEndpointInfraState: + k8s_resource_group_name = _endpoint_id_to_k8s_resource_group_name(endpoint_id) + + # Assume leader + worker share the same user-set env vars + common_params = self._get_common_endpoint_params_for_lws_type(lws_config) + + replicas = lws_config["spec"]["replicas"] + prewarm = False # not provided here + high_priority = ( + lws_config["spec"]["leaderWorkerTemplate"]["leaderTemplate"]["spec"][ + "priorityClassName" + ] + == LAUNCH_HIGH_PRIORITY_CLASS + ) + nodes_per_worker = lws_config["spec"]["leaderWorkerTemplate"]["size"] + + infra_state = ModelEndpointInfraState( + deployment_name=k8s_resource_group_name, + aws_role=common_params["aws_role"], + results_s3_bucket=common_params["results_s3_bucket"], + child_fn_info=None, + labels=common_params["labels"], + prewarm=prewarm, + high_priority=high_priority, + deployment_state=ModelEndpointDeploymentState( + min_workers=replicas, + max_workers=replicas, # We don't have any notion of autoscaling for LWS + per_worker=int(1), # TODO update this if we support LWS autoscaling + available_workers=replicas, # TODO unfortunately it doesn't look like we can get this from the LWS CRD, so this is kind of a dummy value + unavailable_workers=0, + ), + resource_state=ModelEndpointResourceState( + cpus=common_params["cpus"], + gpus=common_params["gpus"], + gpu_type=common_params["gpu_type"], # type: ignore + memory=common_params["memory"], + storage=common_params["storage"], + nodes_per_worker=nodes_per_worker, + optimize_costs=False, + ), + user_config_state=self._translate_k8s_config_maps_to_user_config_data( + k8s_resource_group_name, config_maps + ), + image=common_params["image"], + num_queued_items=None, + ) + return infra_state async def _get_all_resources( @@ -1282,38 +2112,77 @@ async def _get_all_resources( vpas = [] else: raise + try: + keda_scaled_objects = ( + await custom_objects_client.list_namespaced_custom_object( + group="keda.sh", + version="v1alpha1", + namespace=hmi_config.endpoint_namespace, + plural="scaledobjects", + ) + )["items"] + except ApiException as e: + if e.status == 404: + keda_scaled_objects = [] + else: + raise + + try: + leader_worker_sets = ( + await custom_objects_client.list_namespaced_custom_object( + group="leaderworkerset.x-k8s.io", + version="v1", + namespace=hmi_config.endpoint_namespace, + plural="leaderworkersets", + ) + )["items"] + except ApiException as e: + if e.status == 404: + leader_worker_sets = [] + else: + raise deployments_by_name = {deployment.metadata.name: deployment for deployment in deployments} hpas_by_name = {hpa.metadata.name: hpa for hpa in hpas} vpas_by_name = {vpa["metadata"]["name"]: vpa for vpa in vpas} + keda_scaled_objects_by_name = {kso["metadata"]["name"]: kso for kso in keda_scaled_objects} + leader_worker_sets_by_name = {lws["metadata"]["name"]: lws for lws in leader_worker_sets} all_config_maps = await self._get_all_config_maps() - # can safely assume hpa with same name as deployment corresponds to the same LLMEngine Endpoint + # can safely assume hpa with same name as deployment corresponds to the same Launch Endpoint logger.info(f"Orphaned hpas: {set(hpas_by_name).difference(set(deployments_by_name))}") logger.info(f"Orphaned vpas: {set(vpas_by_name).difference(set(deployments_by_name))}") infra_states = {} - logger.info(f"Got data for {list(deployments_by_name.keys())}") + logger.info( + f"Got data for {list(deployments_by_name.keys())} and {list(leader_worker_sets_by_name.keys())}" + ) for name, deployment_config in deployments_by_name.items(): try: hpa_config = hpas_by_name.get(name, None) vpa_config = vpas_by_name.get(name, None) + keda_scaled_object_config = keda_scaled_objects_by_name.get(name, None) common_params = self._get_common_endpoint_params(deployment_config) - llm_engine_container = self._get_llm_engine_container(deployment_config) + launch_container = self._get_launch_container(deployment_config) - envlist = llm_engine_container.env + envlist = launch_container.env # Convert as early as possible to Optional[bool] to avoid bugs prewarm = str_to_bool(self._get_env_value_from_envlist(envlist, "PREWARM")) high_priority = ( deployment_config.spec.template.spec.priority_class_name - == LLM_ENGINE_HIGH_PRIORITY_CLASS + == LAUNCH_HIGH_PRIORITY_CLASS ) if hpa_config: # Assume it's a sync endpoint # TODO I think this is correct but only barely, it introduces a coupling between - # an HPA existing and an endpoint being a sync endpoint. The "more correct" + # an HPA (or keda SO) existing and an endpoint being a sync endpoint. The "more correct" # thing to do is to query the db to get the endpoints, but it doesn't belong here horizontal_autoscaling_params = self._get_sync_autoscaling_params(hpa_config) + elif keda_scaled_object_config: + # Also assume it's a sync endpoint + horizontal_autoscaling_params = self._get_sync_autoscaling_params_from_keda( + keda_scaled_object_config + ) else: horizontal_autoscaling_params = self._get_async_autoscaling_params( deployment_config @@ -1342,6 +2211,7 @@ async def _get_all_resources( gpu_type=common_params["gpu_type"], # type: ignore memory=common_params["memory"], storage=common_params["storage"], + nodes_per_worker=1, # We're in a Deployment case, so nodes_per_worker is 1 optimize_costs=(vertical_autoscaling_params is not None), ), user_config_state=self._translate_k8s_config_maps_to_user_config_data( @@ -1350,7 +2220,7 @@ async def _get_all_resources( image=common_params["image"], num_queued_items=None, ) - if name.startswith("llm-engine-endpoint-id-"): + if name.startswith("launch-endpoint-id-"): key = _k8s_resource_group_name_to_endpoint_id(name) is_key_an_endpoint_id = True else: @@ -1360,9 +2230,27 @@ async def _get_all_resources( infra_states[key] = (is_key_an_endpoint_id, infra_state) except Exception: logger.exception(f"Error parsing deployment {name}") + for name, lws_config in leader_worker_sets_by_name.items(): + # name.startswith("launch-endpoint-id-") should always be true, the other case is a legacy. + key = _k8s_resource_group_name_to_endpoint_id(name) + is_key_an_endpoint_id = True + endpoint_id = key + deployment_name = name + endpoint_type = ( + ModelEndpointType.STREAMING + ) # TODO change if we ever support other endpoint types + infra_states[key] = ( + is_key_an_endpoint_id, + await self._get_resources_from_lws_type( + endpoint_id, deployment_name, endpoint_type, lws_config, all_config_maps + ), + ) return infra_states async def _delete_resources_async(self, endpoint_id: str, deployment_name: str) -> bool: + + # TODO check that this implementation actually works for multinode if/when we decide to support that + lws_delete_succeeded = await self._delete_lws(endpoint_id=endpoint_id) deployment_delete_succeeded = await self._delete_deployment( endpoint_id=endpoint_id, deployment_name=deployment_name ) @@ -1370,9 +2258,12 @@ async def _delete_resources_async(self, endpoint_id: str, deployment_name: str) endpoint_id=endpoint_id, deployment_name=deployment_name ) await self._delete_vpa(endpoint_id=endpoint_id) - return deployment_delete_succeeded and config_map_delete_succeeded + await self._delete_pdb(endpoint_id=endpoint_id) + return (deployment_delete_succeeded or lws_delete_succeeded) and config_map_delete_succeeded async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) -> bool: + lws_delete_succeeded = await self._delete_lws(endpoint_id=endpoint_id) + deployment_delete_succeeded = await self._delete_deployment( endpoint_id=endpoint_id, deployment_name=deployment_name, @@ -1383,10 +2274,19 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) - service_delete_succeeded = await self._delete_service( endpoint_id=endpoint_id, deployment_name=deployment_name ) + lws_service_delete_succeeded = await self._delete_lws_service( + endpoint_id=endpoint_id, deployment_name=deployment_name + ) + # we should have created exactly one of an HPA or a keda scaled object hpa_delete_succeeded = await self._delete_hpa( endpoint_id=endpoint_id, deployment_name=deployment_name ) + keda_scaled_object_succeeded = await self._delete_keda_scaled_object( + endpoint_id=endpoint_id + ) await self._delete_vpa(endpoint_id=endpoint_id) + await self._delete_pdb(endpoint_id=endpoint_id) + await self._delete_lws_service_entry(endpoint_id=endpoint_id) destination_rule_delete_succeeded = await self._delete_destination_rule( endpoint_id=endpoint_id @@ -1399,7 +2299,7 @@ async def _delete_resources_sync(self, endpoint_id: str, deployment_name: str) - deployment_delete_succeeded and config_map_delete_succeeded and service_delete_succeeded - and hpa_delete_succeeded + and (hpa_delete_succeeded or keda_scaled_object_succeeded) and destination_rule_delete_succeeded and virtual_service_delete_succeeded - ) + ) or (lws_delete_succeeded and config_map_delete_succeeded and lws_service_delete_succeeded) diff --git a/server/llm_engine_server/infra/gateways/resources/k8s_resource_types.py b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py similarity index 77% rename from server/llm_engine_server/infra/gateways/resources/k8s_resource_types.py rename to model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py index f140fd5d..c1c64c34 100644 --- a/server/llm_engine_server/infra/gateways/resources/k8s_resource_types.py +++ b/model-engine/model_engine_server/infra/gateways/resources/k8s_resource_types.py @@ -1,31 +1,30 @@ import hashlib -import json +import os from datetime import datetime from typing import Any, Dict, List, Optional, Sequence, TypedDict, Union -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.model_endpoints import BrokerName, BrokerType -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.common.resource_limits import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.model_endpoints import BrokerName, BrokerType +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.env_vars import CIRCLECI, GIT_TAG +from model_engine_server.common.resource_limits import ( FORWARDER_CPU_USAGE, FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_USAGE, + FORWARDER_WORKER_COUNT, ) -from llm_engine_server.common.serialization_utils import bool_to_str, python_json_to_b64 -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.entities import ( - ArtifactLike, +from model_engine_server.common.serialization_utils import python_json_to_b64 +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities import ( ModelEndpointConfig, RunnableImageLike, StreamingEnhancedRunnableImageFlavor, TritonEnhancedRunnableImageFlavor, - ZipArtifactFlavor, ) -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CONVERTED_FROM_ARTIFACT_LIKE_KEY, ) -from llm_engine_server.infra.gateways.k8s_resource_parser import ( +from model_engine_server.infra.gateways.k8s_resource_parser import ( get_node_port, get_target_concurrency_from_per_worker_value, ) @@ -33,10 +32,6 @@ __all__: Sequence[str] = ( "BatchJobOrchestrationJobArguments", "CommonEndpointParams", - "DeploymentArtifactAsyncCpuArguments", - "DeploymentArtifactAsyncGpuArguments", - "DeploymentArtifactSyncCpuArguments", - "DeploymentArtifactSyncGpuArguments", "DeploymentRunnableImageAsyncCpuArguments", "DeploymentRunnableImageAsyncGpuArguments", "DeploymentRunnableImageStreamingCpuArguments", @@ -47,6 +42,7 @@ "DeploymentTritonEnhancedRunnableImageAsyncGpuArguments", "DeploymentTritonEnhancedRunnableImageSyncCpuArguments", "DeploymentTritonEnhancedRunnableImageSyncGpuArguments", + "DestinationRuleArguments", "DictStrInt", "DictStrStr", "DockerImageBatchJobCpuArguments", @@ -56,25 +52,26 @@ "HorizontalAutoscalingEndpointParams", "HorizontalPodAutoscalerArguments", "ImageCacheArguments", - "LLM_ENGINE_DEFAULT_PRIORITY_CLASS", - "LLM_ENGINE_HIGH_PRIORITY_CLASS", + "CronTriggerArguments", + "LAUNCH_DEFAULT_PRIORITY_CLASS", + "LAUNCH_HIGH_PRIORITY_CLASS", "ResourceArguments", "ServiceArguments", "UserConfigArguments", "VerticalAutoscalingEndpointParams", "VerticalPodAutoscalerArguments", + "VirtualServiceArguments", "get_endpoint_resource_arguments_from_request", ) -# Constants for LLMEngine priority classes -LLM_ENGINE_HIGH_PRIORITY_CLASS = "llm-engine-high-priority" -LLM_ENGINE_DEFAULT_PRIORITY_CLASS = "llm-engine-default-priority" +# Constants for Launch priority classes +LAUNCH_HIGH_PRIORITY_CLASS = "model-engine-high-priority" +LAUNCH_DEFAULT_PRIORITY_CLASS = "model-engine-default-priority" -KUBERNETES_MAX_LENGTH = 64 +IMAGE_HASH_MAX_LENGTH = 32 FORWARDER_PORT = 5000 USER_CONTAINER_PORT = 5005 ARTIFACT_LIKE_CONTAINER_PORT = FORWARDER_PORT -FORWARDER_IMAGE_TAG = "54f8f73bfb1cce62a2b42326ccf9f49b5b145126" class _BaseResourceArguments(TypedDict): @@ -86,6 +83,7 @@ class _BaseResourceArguments(TypedDict): PRODUCT: str CREATED_BY: str OWNER: str + GIT_TAG: str class _BaseEndpointArguments(_BaseResourceArguments): @@ -109,7 +107,7 @@ class _BaseDeploymentArguments(_BaseEndpointArguments): PRIORITY: str IMAGE: str IMAGE_HASH: str - DATADOG_TRACE_ENABLED: str + DD_TRACE_ENABLED: str CPUS: str MEMORY: str STORAGE_DICT: DictStrStr @@ -140,6 +138,7 @@ class _SyncRunnableImageDeploymentArguments(TypedDict): """Keyword-arguments for substituting into sync deployment templates.""" FORWARDER_PORT: int + FORWARDER_WORKER_COUNT: int class _StreamingDeploymentArguments(TypedDict): @@ -147,17 +146,7 @@ class _StreamingDeploymentArguments(TypedDict): FORWARDER_PORT: int STREAMING_PREDICT_ROUTE: str - - -class _ArtifactDeploymentArguments(_BaseDeploymentArguments): - """Keyword-arguments for substituting into artifact deployment templates.""" - - BUNDLE_URL: str - BASE_PATH: str - LOAD_PREDICT_FN_MODULE_PATH: str - LOAD_MODEL_FN_MODULE_PATH: str - CHILD_FN_INFO: str - PREWARM: str + FORWARDER_WORKER_COUNT: int class _RunnableImageDeploymentArguments(_BaseDeploymentArguments): @@ -169,11 +158,11 @@ class _RunnableImageDeploymentArguments(_BaseDeploymentArguments): HEALTHCHECK_ROUTE: str READINESS_INITIAL_DELAY: int INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH: str - FORWARDER_IMAGE_TAG: str FORWARDER_CONFIG_FILE_NAME: str FORWARDER_CPUS_LIMIT: float FORWARDER_MEMORY_LIMIT: str FORWARDER_STORAGE_LIMIT: str + FORWARDER_EXTRA_ROUTES: List[str] USER_CONTAINER_PORT: int @@ -183,6 +172,7 @@ class _JobArguments(_BaseResourceArguments): JOB_ID: str BATCH_JOB_MAX_RUNTIME: int BATCH_JOB_TTL_SECONDS_AFTER_FINISHED: int + REQUEST_ID: str class _DockerImageBatchJobArguments(_JobArguments): @@ -199,6 +189,7 @@ class _DockerImageBatchJobArguments(_JobArguments): LOCAL_FILE_NAME: str FILE_CONTENTS_B64ENCODED: str COMMAND: List[str] + BATCH_JOB_NUM_WORKERS: int class _GpuArguments(TypedDict): @@ -220,26 +211,10 @@ class _TritonArguments(TypedDict): TRITON_COMMIT_TAG: str -class DeploymentArtifactAsyncCpuArguments(_ArtifactDeploymentArguments, _AsyncDeploymentArguments): - """Keyword-arguments for substituting into CPU async deployment templates with artifacts.""" - - -class DeploymentArtifactAsyncGpuArguments( - _ArtifactDeploymentArguments, _AsyncDeploymentArguments, _GpuArguments -): - """Keyword-arguments for substituting into GPU async deployment templates with artifacts.""" - - -class DeploymentArtifactSyncCpuArguments( - _ArtifactDeploymentArguments, _SyncArtifactDeploymentArguments -): - """Keyword-arguments for substituting into CPU sync deployment templates with artifacts.""" - - -class DeploymentArtifactSyncGpuArguments( - _ArtifactDeploymentArguments, _SyncArtifactDeploymentArguments, _GpuArguments -): - """Keyword-arguments for substituting into GPU sync deployment templates with artifacts.""" +class _LeaderWorkerSetArguments(TypedDict): + LWS_SIZE: int + WORKER_COMMAND: List[str] + WORKER_ENV: List[Dict[str, Any]] class DeploymentRunnableImageSyncCpuArguments( @@ -320,6 +295,15 @@ class DeploymentTritonEnhancedRunnableImageAsyncGpuArguments( """ +class LeaderWorkerSetRunnableImageStreamingGpuArguments( + _RunnableImageDeploymentArguments, + _StreamingDeploymentArguments, + _GpuArguments, + _LeaderWorkerSetArguments, +): + """Keyword-arguments for substituting into GPU streaming LeaderWorkerSet templates for runnable images.""" + + class HorizontalPodAutoscalerArguments(_BaseEndpointArguments): """Keyword-arguments for substituting into horizontal pod autoscaler templates.""" @@ -329,6 +313,17 @@ class HorizontalPodAutoscalerArguments(_BaseEndpointArguments): API_VERSION: str +class KedaScaledObjectArguments(_BaseEndpointArguments): + MIN_WORKERS: int + MAX_WORKERS: int + CONCURRENCY: float + REDIS_HOST_PORT: str + REDIS_DB_INDEX: str + SERVICEBUS_NAMESPACE: Optional[str] + AUTHENTICATION_REF: str + PROMETHEUS_SERVER_ADDRESS: str + + class UserConfigArguments(_BaseEndpointArguments): """Keyword-arguments for substituting into user-config templates.""" @@ -355,6 +350,17 @@ class ServiceArguments(_BaseEndpointArguments): NODE_PORT_DICT: DictStrInt +class LwsServiceArguments(ServiceArguments): + """Keyword-arguments for substituting into service templates for LWS. + Need this to override the service name for LWS.""" + + SERVICE_NAME_OVERRIDE: str + + +class DestinationRuleArguments(_BaseEndpointArguments): + """Keyword-arguments for substituting into destination-rule templates.""" + + class VerticalPodAutoscalerArguments(_BaseEndpointArguments): """Keyword-arguments for substituting into vertical pod autoscaler templates.""" @@ -362,6 +368,24 @@ class VerticalPodAutoscalerArguments(_BaseEndpointArguments): MEMORY: str +class PodDisruptionBudgetArguments(_BaseEndpointArguments): + """Keyword-arguments for substituting into pod disruption budget templates.""" + + pass + + +class VirtualServiceArguments(_BaseEndpointArguments): + """Keyword-arguments for substituting into virtual-service templates.""" + + DNS_HOST_DOMAIN: str + + +class LwsServiceEntryArguments(_BaseEndpointArguments): + """Keyword-arguments for substituting into istio service-entry templates to support LWS.""" + + SERVICE_NAME_OVERRIDE: str + + class BatchJobOrchestrationJobArguments(_JobArguments): """Keyword-arguments for substituting into batch-job-orchestration-job templates.""" @@ -385,6 +409,23 @@ class ImageCacheArguments(TypedDict): NAMESPACE: str +class CronTriggerArguments(TypedDict): + """Keyword-arguments for substituting into cronjob trigger templates.""" + + HOST: str + NAME: str + CREATED_BY: str + OWNER: str + TEAM: str + PRODUCT: str + TRIGGER_ID: str + CRON_SCHEDULE: str + DOCKER_IMAGE_BATCH_JOB_BUNDLE_ID: str + JOB_CONFIG: str + JOB_METADATA: str + BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS: int + + class CommonEndpointParams(TypedDict): cpus: str memory: str @@ -412,10 +453,6 @@ class VerticalAutoscalingEndpointParams(TypedDict): EndpointResourceArguments = Union[ - DeploymentArtifactAsyncCpuArguments, - DeploymentArtifactAsyncGpuArguments, - DeploymentArtifactSyncCpuArguments, - DeploymentArtifactSyncGpuArguments, DeploymentRunnableImageAsyncCpuArguments, DeploymentRunnableImageAsyncGpuArguments, DeploymentRunnableImageStreamingCpuArguments, @@ -426,11 +463,13 @@ class VerticalAutoscalingEndpointParams(TypedDict): DeploymentTritonEnhancedRunnableImageAsyncGpuArguments, DeploymentTritonEnhancedRunnableImageSyncCpuArguments, DeploymentTritonEnhancedRunnableImageSyncGpuArguments, + DestinationRuleArguments, EndpointConfigArguments, HorizontalPodAutoscalerArguments, ServiceArguments, UserConfigArguments, VerticalPodAutoscalerArguments, + VirtualServiceArguments, ] ResourceArguments = Union[ @@ -439,16 +478,20 @@ class VerticalAutoscalingEndpointParams(TypedDict): DockerImageBatchJobCpuArguments, DockerImageBatchJobGpuArguments, ImageCacheArguments, + CronTriggerArguments, ] +def compute_image_hash(image: str) -> str: + return str(hashlib.sha256(str(image).encode()).hexdigest())[:IMAGE_HASH_MAX_LENGTH] + + def container_start_triton_cmd( triton_model_repository: str, - triton_model_replicas: Dict[str, int], + triton_model_replicas: Union[Dict[str, int], Dict[str, str]], ipv6_healthcheck: bool = False, ) -> List[str]: - # NOTE: this path is set in the Trtion-specific Dockerfile: - # std-ml-srv/ml_serve/triton/Dockerfile + # NOTE: this path is set in the Triton-specific Dockerfile: triton_start_command: List[str] = [ "python", "/install/tritonserver.py", @@ -469,6 +512,7 @@ def get_endpoint_resource_arguments_from_request( sqs_queue_url: str, endpoint_resource_name: str, api_version: str = "", + service_name_override: Optional[str] = None, ) -> EndpointResourceArguments: """Get the arguments for the endpoint resource templates from the request. @@ -485,15 +529,10 @@ def get_endpoint_resource_arguments_from_request( team = k8s_labels.get("team", "") product = k8s_labels.get("product", "") storage = build_endpoint_request.storage - prewarm = bool_to_str(build_endpoint_request.prewarm) or "false" - sqs_profile = "default" # TODO: Make this configurable - s3_bucket = ml_infra_config().s3_bucket + sqs_profile = f"eks-{infra_config().profile_ml_worker}" # TODO: Make this configurable + s3_bucket = infra_config().s3_bucket - load_predict_fn_module_path = "" - load_model_fn_module_path = "" - if isinstance(flavor, ZipArtifactFlavor): - load_predict_fn_module_path = flavor.load_predict_fn_module_path - load_model_fn_module_path = flavor.load_model_fn_module_path + service_name_override = service_name_override or k8s_resource_group_name storage_dict = DictStrStr("") if storage is not None: @@ -501,27 +540,51 @@ def get_endpoint_resource_arguments_from_request( change_cause_message = ( f"Deployment at {datetime.utcnow()} UTC. " - f"Using deployment constructed from model bundle ID: {model_bundle.id}, " - f"model bundle name: {model_bundle.name}, " - f"endpoint ID: {model_endpoint_record.id}" + f"Using deployment constructed from model bundle ID {model_bundle.id}, " + f"model bundle name {model_bundle.name}, " + f"endpoint ID {model_endpoint_record.id}" ) - priority = LLM_ENGINE_DEFAULT_PRIORITY_CLASS + priority = LAUNCH_DEFAULT_PRIORITY_CLASS if build_endpoint_request.high_priority: - priority = LLM_ENGINE_HIGH_PRIORITY_CLASS + priority = LAUNCH_HIGH_PRIORITY_CLASS - image_hash = str(hashlib.md5(str(request.image).encode()).hexdigest())[:KUBERNETES_MAX_LENGTH] + image_hash = compute_image_hash(request.image) # In Circle CI, we use Redis on localhost instead of SQS - broker_name = BrokerName.SQS.value if not CIRCLECI else BrokerName.REDIS.value - broker_type = BrokerType.SQS.value if not CIRCLECI else BrokerType.REDIS.value - datadog_trace_enabled = hmi_config.datadog_trace_enabled - if broker_type == BrokerType.REDIS.value: + if CIRCLECI: + broker_name = BrokerName.REDIS.value + broker_type = BrokerType.REDIS.value + elif infra_config().cloud_provider == "azure": + broker_name = BrokerName.SERVICEBUS.value + broker_type = BrokerType.SERVICEBUS.value + else: + broker_name = BrokerName.SQS.value + broker_type = BrokerType.SQS.value + dd_trace_enabled = hmi_config.dd_trace_enabled + if broker_type != BrokerType.SQS.value: sqs_queue_url = "" main_env = [] if isinstance(flavor, RunnableImageLike) and flavor.env: main_env = [{"name": key, "value": value} for key, value in flavor.env.items()] + main_env.append({"name": "AWS_PROFILE", "value": build_endpoint_request.aws_role}) + # NOTE: /opt/.aws/config is where service_template_config_map.yaml mounts the AWS config file, point to the mount for boto clients + main_env.append({"name": "AWS_CONFIG_FILE", "value": "/opt/.aws/config"}) + abs_account_name = os.getenv("ABS_ACCOUNT_NAME") + if abs_account_name is not None: + main_env.append({"name": "ABS_ACCOUNT_NAME", "value": abs_account_name}) + + # LeaderWorkerSet exclusive + worker_env = None + if isinstance(flavor, RunnableImageLike) and flavor.worker_env is not None: + worker_env = [{"name": key, "value": value} for key, value in flavor.worker_env.items()] + worker_env.append({"name": "AWS_PROFILE", "value": build_endpoint_request.aws_role}) + worker_env.append({"name": "AWS_CONFIG_FILE", "value": "/opt/.aws/config"}) + + worker_command = None + if isinstance(flavor, RunnableImageLike) and flavor.worker_command is not None: + worker_command = flavor.worker_command infra_service_config_volume_mount_path = "/infra-config" forwarder_config_file_name = "service--forwarder.yaml" @@ -538,7 +601,9 @@ def get_endpoint_resource_arguments_from_request( raise ValueError( "flavor.env['BASE_PATH'] is required for runnable image converted from artifact like bundle" ) - infra_service_config_volume_mount_path = f"{flavor.env['BASE_PATH']}/ml_infra_core/llm_engine_server.core/llm_engine_server.core/configs" + infra_service_config_volume_mount_path = ( + f"{flavor.env['BASE_PATH']}/model-engine/model_engine_server/core/configs" + ) forwarder_config_file_name = "service--forwarder-runnable-img-converted-from-artifact.yaml" triton_command = "" @@ -568,13 +633,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -588,13 +654,13 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Async Deployment Arguments CELERY_S3_BUCKET=s3_bucket, QUEUE=sqs_queue_name, @@ -616,13 +682,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -636,13 +703,13 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Async Deployment Arguments CELERY_S3_BUCKET=s3_bucket, QUEUE=sqs_queue_name, @@ -666,13 +733,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -687,15 +755,16 @@ def get_endpoint_resource_arguments_from_request( STREAMING_PREDICT_ROUTE=flavor.streaming_predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Streaming Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, ) elif endpoint_resource_name == "deployment-runnable-image-streaming-gpu": assert isinstance(flavor, StreamingEnhancedRunnableImageFlavor) @@ -710,13 +779,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -731,15 +801,16 @@ def get_endpoint_resource_arguments_from_request( STREAMING_PREDICT_ROUTE=flavor.streaming_predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Streaming Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, # GPU Deployment Arguments GPU_TYPE=build_endpoint_request.gpu_type.value, GPUS=build_endpoint_request.gpus, @@ -756,13 +827,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -776,15 +848,16 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, ) elif endpoint_resource_name == "deployment-runnable-image-sync-gpu": assert isinstance(flavor, RunnableImageLike) @@ -799,13 +872,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -819,15 +893,16 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, # GPU Deployment Arguments GPU_TYPE=build_endpoint_request.gpu_type.value, GPUS=build_endpoint_request.gpus, @@ -844,13 +919,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -864,13 +940,13 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Async Deployment Arguments CELERY_S3_BUCKET=s3_bucket, QUEUE=sqs_queue_name, @@ -900,13 +976,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -920,13 +997,13 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Async Deployment Arguments CELERY_S3_BUCKET=s3_bucket, QUEUE=sqs_queue_name, @@ -958,13 +1035,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -978,15 +1056,16 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, # Triton Deployment Arguments TRITON_MODEL_REPOSITORY=flavor.triton_model_repository, TRITON_CPUS=str(flavor.triton_num_cpu), @@ -1009,13 +1088,14 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, @@ -1029,15 +1109,16 @@ def get_endpoint_resource_arguments_from_request( PREDICT_ROUTE=flavor.predict_route, HEALTHCHECK_ROUTE=flavor.healthcheck_route, READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, - FORWARDER_IMAGE_TAG=FORWARDER_IMAGE_TAG, INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, # Sync Deployment Arguments FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, # GPU Deployment Arguments GPU_TYPE=build_endpoint_request.gpu_type.value, GPUS=build_endpoint_request.gpus, @@ -1050,138 +1131,12 @@ def get_endpoint_resource_arguments_from_request( TRITON_COMMAND=triton_command, TRITON_COMMIT_TAG=flavor.triton_commit_tag, ) - elif endpoint_resource_name == "deployment-artifact-async-cpu": - assert isinstance(flavor, ArtifactLike) - return DeploymentArtifactAsyncCpuArguments( - # Base resource arguments - RESOURCE_NAME=k8s_resource_group_name, - NAMESPACE=hmi_config.endpoint_namespace, - ENDPOINT_ID=model_endpoint_record.id, - ENDPOINT_NAME=model_endpoint_record.name, - TEAM=team, - PRODUCT=product, - CREATED_BY=created_by, - OWNER=owner, - # Base deployment arguments - CHANGE_CAUSE_MESSAGE=change_cause_message, - AWS_ROLE=build_endpoint_request.aws_role, - PRIORITY=priority, - IMAGE=request.image, - IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, - CPUS=str(build_endpoint_request.cpus), - MEMORY=str(build_endpoint_request.memory), - STORAGE_DICT=storage_dict, - BASE_PATH="/app", - PER_WORKER=build_endpoint_request.per_worker, - MIN_WORKERS=build_endpoint_request.min_workers, - MAX_WORKERS=build_endpoint_request.max_workers, - RESULTS_S3_BUCKET=s3_bucket, - # Artifact Arguments - BUNDLE_URL=flavor.location, - LOAD_PREDICT_FN_MODULE_PATH=load_predict_fn_module_path, - LOAD_MODEL_FN_MODULE_PATH=load_model_fn_module_path, - CHILD_FN_INFO=json.dumps( - build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info else {} - ), - PREWARM=prewarm, - # Async Deployment Arguments - CELERY_S3_BUCKET=s3_bucket, - QUEUE=sqs_queue_name, - BROKER_NAME=broker_name, - BROKER_TYPE=broker_type, - SQS_QUEUE_URL=sqs_queue_url, - SQS_PROFILE=sqs_profile, - ) - elif endpoint_resource_name == "deployment-artifact-async-gpu": - assert isinstance(flavor, ArtifactLike) - assert build_endpoint_request.gpu_type is not None - return DeploymentArtifactAsyncGpuArguments( - # Base resource arguments - RESOURCE_NAME=k8s_resource_group_name, - NAMESPACE=hmi_config.endpoint_namespace, - ENDPOINT_ID=model_endpoint_record.id, - ENDPOINT_NAME=model_endpoint_record.name, - TEAM=team, - PRODUCT=product, - CREATED_BY=created_by, - OWNER=owner, - # Base deployment arguments - CHANGE_CAUSE_MESSAGE=change_cause_message, - AWS_ROLE=build_endpoint_request.aws_role, - PRIORITY=priority, - IMAGE=request.image, - IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, - CPUS=str(build_endpoint_request.cpus), - MEMORY=str(build_endpoint_request.memory), - STORAGE_DICT=storage_dict, - BASE_PATH="/app", - PER_WORKER=build_endpoint_request.per_worker, - MIN_WORKERS=build_endpoint_request.min_workers, - MAX_WORKERS=build_endpoint_request.max_workers, - RESULTS_S3_BUCKET=s3_bucket, - # Artifact Arguments - BUNDLE_URL=flavor.location, - LOAD_PREDICT_FN_MODULE_PATH=load_predict_fn_module_path, - LOAD_MODEL_FN_MODULE_PATH=load_model_fn_module_path, - CHILD_FN_INFO=json.dumps( - build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info else {} - ), - PREWARM=prewarm, - # Async Deployment Arguments - CELERY_S3_BUCKET=s3_bucket, - QUEUE=sqs_queue_name, - BROKER_NAME=broker_name, - BROKER_TYPE=broker_type, - SQS_QUEUE_URL=sqs_queue_url, - SQS_PROFILE=sqs_profile, - # GPU Deployment Arguments - GPU_TYPE=build_endpoint_request.gpu_type.value, - GPUS=build_endpoint_request.gpus, - ) - elif endpoint_resource_name == "deployment-artifact-sync-cpu": - assert isinstance(flavor, ArtifactLike) - return DeploymentArtifactSyncCpuArguments( - # Base resource arguments - RESOURCE_NAME=k8s_resource_group_name, - NAMESPACE=hmi_config.endpoint_namespace, - ENDPOINT_ID=model_endpoint_record.id, - ENDPOINT_NAME=model_endpoint_record.name, - TEAM=team, - PRODUCT=product, - CREATED_BY=created_by, - OWNER=owner, - # Base deployment arguments - CHANGE_CAUSE_MESSAGE=change_cause_message, - AWS_ROLE=build_endpoint_request.aws_role, - PRIORITY=priority, - IMAGE=request.image, - IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, - CPUS=str(build_endpoint_request.cpus), - MEMORY=str(build_endpoint_request.memory), - STORAGE_DICT=storage_dict, - BASE_PATH="/app", - PER_WORKER=build_endpoint_request.per_worker, - MIN_WORKERS=build_endpoint_request.min_workers, - MAX_WORKERS=build_endpoint_request.max_workers, - RESULTS_S3_BUCKET=s3_bucket, - # Artifact Arguments - BUNDLE_URL=flavor.location, - LOAD_PREDICT_FN_MODULE_PATH=load_predict_fn_module_path, - LOAD_MODEL_FN_MODULE_PATH=load_model_fn_module_path, - CHILD_FN_INFO=json.dumps( - build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info else {} - ), - PREWARM=prewarm, - # Sync Artifact DeploymentArguments Arguments - ARTIFACT_LIKE_CONTAINER_PORT=ARTIFACT_LIKE_CONTAINER_PORT, - ) - elif endpoint_resource_name == "deployment-artifact-sync-gpu": - assert isinstance(flavor, ArtifactLike) + elif endpoint_resource_name == "leader-worker-set-streaming-gpu": + assert isinstance(flavor, StreamingEnhancedRunnableImageFlavor) assert build_endpoint_request.gpu_type is not None - return DeploymentArtifactSyncGpuArguments( + assert worker_command is not None + assert worker_env is not None + return LeaderWorkerSetRunnableImageStreamingGpuArguments( # Base resource arguments RESOURCE_NAME=k8s_resource_group_name, NAMESPACE=hmi_config.endpoint_namespace, @@ -1191,34 +1146,45 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, # Base deployment arguments CHANGE_CAUSE_MESSAGE=change_cause_message, AWS_ROLE=build_endpoint_request.aws_role, PRIORITY=priority, IMAGE=request.image, IMAGE_HASH=image_hash, - DATADOG_TRACE_ENABLED=datadog_trace_enabled, + DD_TRACE_ENABLED=str(dd_trace_enabled), CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), STORAGE_DICT=storage_dict, - BASE_PATH="/app", PER_WORKER=build_endpoint_request.per_worker, MIN_WORKERS=build_endpoint_request.min_workers, MAX_WORKERS=build_endpoint_request.max_workers, RESULTS_S3_BUCKET=s3_bucket, - # Artifact Arguments - BUNDLE_URL=flavor.location, - LOAD_PREDICT_FN_MODULE_PATH=load_predict_fn_module_path, - LOAD_MODEL_FN_MODULE_PATH=load_model_fn_module_path, - CHILD_FN_INFO=json.dumps( - build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info else {} - ), - PREWARM=prewarm, - # Sync Artifact DeploymentArguments Arguments - ARTIFACT_LIKE_CONTAINER_PORT=ARTIFACT_LIKE_CONTAINER_PORT, - # GPU Deployment Arguments + # Runnable Image Arguments + MAIN_ENV=main_env, + COMMAND=flavor.streaming_command, + PREDICT_ROUTE=flavor.predict_route, + STREAMING_PREDICT_ROUTE=flavor.streaming_predict_route, + HEALTHCHECK_ROUTE=flavor.healthcheck_route, + READINESS_INITIAL_DELAY=flavor.readiness_initial_delay_seconds, + INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH=infra_service_config_volume_mount_path, + FORWARDER_CONFIG_FILE_NAME=forwarder_config_file_name, + FORWARDER_CPUS_LIMIT=FORWARDER_CPU_USAGE, + FORWARDER_MEMORY_LIMIT=FORWARDER_MEMORY_USAGE, + FORWARDER_STORAGE_LIMIT=FORWARDER_STORAGE_USAGE, + USER_CONTAINER_PORT=USER_CONTAINER_PORT, + FORWARDER_EXTRA_ROUTES=flavor.extra_routes, + # Streaming Arguments + FORWARDER_PORT=FORWARDER_PORT, + FORWARDER_WORKER_COUNT=FORWARDER_WORKER_COUNT, + # GPU Arguments GPU_TYPE=build_endpoint_request.gpu_type.value, GPUS=build_endpoint_request.gpus, + # Leader Worker Set Arguments + LWS_SIZE=build_endpoint_request.nodes_per_worker, + WORKER_COMMAND=worker_command, + WORKER_ENV=worker_env, ) elif endpoint_resource_name == "user-config": app_config_serialized = python_json_to_b64(model_bundle.app_config) @@ -1232,6 +1198,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, CONFIG_DATA_SERIALIZED=app_config_serialized, ) elif endpoint_resource_name == "endpoint-config": @@ -1240,8 +1207,14 @@ def get_endpoint_resource_arguments_from_request( bundle_name=model_bundle.name, post_inference_hooks=build_endpoint_request.post_inference_hooks, user_id=user_id, + billing_queue=hmi_config.billing_queue_arn, + billing_tags=build_endpoint_request.billing_tags, default_callback_url=build_endpoint_request.default_callback_url, default_callback_auth=build_endpoint_request.default_callback_auth, + endpoint_id=model_endpoint_record.id, + endpoint_type=model_endpoint_record.endpoint_type, + bundle_id=model_bundle.id, + labels=build_endpoint_request.labels, ).serialize() return EndpointConfigArguments( # Base resource arguments @@ -1253,6 +1226,7 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, ENDPOINT_CONFIG_SERIALIZED=endpoint_config_serialized, ) elif endpoint_resource_name == "horizontal-pod-autoscaler": @@ -1269,12 +1243,39 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, API_VERSION=api_version, # Autoscaler arguments CONCURRENCY=concurrency, MIN_WORKERS=build_endpoint_request.min_workers, MAX_WORKERS=build_endpoint_request.max_workers, ) + elif endpoint_resource_name == "keda-scaled-object": + concurrency = get_target_concurrency_from_per_worker_value( + build_endpoint_request.per_worker + ) + return KedaScaledObjectArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + # Scaled Object arguments + MIN_WORKERS=build_endpoint_request.min_workers, + MAX_WORKERS=build_endpoint_request.max_workers, + CONCURRENCY=concurrency, + REDIS_HOST_PORT=hmi_config.cache_redis_host_port, + REDIS_DB_INDEX=str(hmi_config.cache_redis_db_index), + SERVICEBUS_NAMESPACE=os.getenv("SERVICEBUS_NAMESPACE"), + AUTHENTICATION_REF="azure-workload-identity", + PROMETHEUS_SERVER_ADDRESS=infra_config().prometheus_server_address + or "dummy-value", # We should never get to "dummy-value", validation should have taken place to ensure prom_server_addr is not None. + ) elif endpoint_resource_name == "service": # Use ClusterIP by default for sync endpoint. # In Circle CI, we use a NodePort to expose the service to CI. @@ -1294,10 +1295,80 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, + # Service arguments + NODE_PORT_DICT=node_port_dict, + SERVICE_TYPE=service_type, + SERVICE_TARGET_PORT=FORWARDER_PORT, + ) + elif endpoint_resource_name == "lws-service": + # Use ClusterIP by default for sync endpoint. + # In Circle CI, we use a NodePort to expose the service to CI. + service_type = "ClusterIP" if not CIRCLECI else "NodePort" + if service_type == "NodePort": + node_port = get_node_port(k8s_resource_group_name) + node_port_dict = DictStrInt(f"nodePort: {node_port}") + else: + node_port_dict = DictStrInt("") + return LwsServiceArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, # Service arguments NODE_PORT_DICT=node_port_dict, SERVICE_TYPE=service_type, SERVICE_TARGET_PORT=FORWARDER_PORT, + # LWS Service args + SERVICE_NAME_OVERRIDE=service_name_override, + ) + elif endpoint_resource_name == "virtual-service": + return VirtualServiceArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + DNS_HOST_DOMAIN=infra_config().dns_host_domain, + ) + elif endpoint_resource_name == "destination-rule": + return DestinationRuleArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + ) + elif endpoint_resource_name == "lws-service-entry": + return LwsServiceEntryArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + # LWS Service Entry args + SERVICE_NAME_OVERRIDE=service_name_override, ) elif endpoint_resource_name == "vertical-pod-autoscaler": return VerticalPodAutoscalerArguments( @@ -1310,8 +1381,22 @@ def get_endpoint_resource_arguments_from_request( PRODUCT=product, CREATED_BY=created_by, OWNER=owner, + GIT_TAG=GIT_TAG, CPUS=str(build_endpoint_request.cpus), MEMORY=str(build_endpoint_request.memory), ) + elif endpoint_resource_name == "pod-disruption-budget": + return PodDisruptionBudgetArguments( + # Base resource arguments + RESOURCE_NAME=k8s_resource_group_name, + NAMESPACE=hmi_config.endpoint_namespace, + ENDPOINT_ID=model_endpoint_record.id, + ENDPOINT_NAME=model_endpoint_record.name, + TEAM=team, + PRODUCT=product, + CREATED_BY=created_by, + OWNER=owner, + GIT_TAG=GIT_TAG, + ) else: raise Exception(f"Unknown resource name: {endpoint_resource_name}") diff --git a/server/llm_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py similarity index 53% rename from server/llm_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py rename to model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py index 33880fb7..d884ab17 100644 --- a/server/llm_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py @@ -1,60 +1,52 @@ from typing import Dict, Optional, Tuple -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import ( ModelEndpointInfraState, ModelEndpointRecord, ModelEndpointType, ) -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.domain.gateways import InferenceAutoscalingMetricsGateway +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, EndpointResourceGatewayCreateOrUpdateResourcesResponse, - QueueInfo, ) -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( K8SEndpointResourceDelegate, ) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, ) -logger = make_logger(filename_wo_ext(__file__)) - - -class SqsQueueInfo(QueueInfo): - """Live endpoints create and use SQS queues. These come with an additional per-queue URL. - - NOTE: broker for this class **MUST** always be SQS. - """ - - queue_url: str +logger = make_logger(logger_name()) - @staticmethod - def new(queue_name: str, queue_url: str) -> "SqsQueueInfo": - return SqsQueueInfo(queue_name=queue_name, broker=BrokerType.SQS, queue_url=queue_url) - -class LiveEndpointResourceGateway(EndpointResourceGateway[SqsQueueInfo]): - def __init__(self, sqs_delegate: SQSEndpointResourceDelegate): +class LiveEndpointResourceGateway(EndpointResourceGateway[QueueInfo]): + def __init__( + self, + queue_delegate: QueueEndpointResourceDelegate, + inference_autoscaling_metrics_gateway: Optional[InferenceAutoscalingMetricsGateway], + ): self.k8s_delegate = K8SEndpointResourceDelegate() - self.sqs_delegate = sqs_delegate + self.queue_delegate = queue_delegate + self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway async def create_queue( self, endpoint_record: ModelEndpointRecord, labels: Dict[str, str], - ) -> SqsQueueInfo: - """Creates a new SQS queue, returning its unique name and queue URL.""" - queue_name, queue_url = await self.sqs_delegate.create_queue_if_not_exists( + ) -> QueueInfo: + """Creates a new queue, returning its unique name and queue URL.""" + queue_name, queue_url = await self.queue_delegate.create_queue_if_not_exists( endpoint_id=endpoint_record.id, endpoint_name=endpoint_record.name, endpoint_created_by=endpoint_record.created_by, endpoint_labels=labels, ) - return SqsQueueInfo.new(queue_name, queue_url) + return QueueInfo(queue_name, queue_url) async def create_or_update_resources( self, request: CreateOrUpdateResourcesRequest @@ -67,13 +59,16 @@ async def create_or_update_resources( q = await self.create_queue(endpoint_record, request.build_endpoint_request.labels) queue_name: Optional[str] = q.queue_name queue_url: Optional[str] = q.queue_url - destination: str = q.queue_name else: - destination = f"llm-engine-endpoint-id-{endpoint_record.id.replace('_', '-')}" queue_name = None queue_url = None - await self.k8s_delegate.create_or_update_resources( + if self.inference_autoscaling_metrics_gateway is not None: + await self.inference_autoscaling_metrics_gateway.create_or_update_resources( + endpoint_record.id + ) + + destination: str = await self.k8s_delegate.create_or_update_resources( request=request, sqs_queue_name=queue_name, sqs_queue_url=queue_url, @@ -90,11 +85,16 @@ async def get_resources( ) if endpoint_type == ModelEndpointType.ASYNC: - sqs_attributes = await self.sqs_delegate.get_queue_attributes(endpoint_id=endpoint_id) - if "ApproximateNumberOfMessages" in sqs_attributes["Attributes"]: + sqs_attributes = await self.queue_delegate.get_queue_attributes(endpoint_id=endpoint_id) + if ( + "Attributes" in sqs_attributes + and "ApproximateNumberOfMessages" in sqs_attributes["Attributes"] + ): resources.num_queued_items = int( sqs_attributes["Attributes"]["ApproximateNumberOfMessages"] ) + elif "active_message_count" in sqs_attributes: # from ASBQueueEndpointResourceDelegate + resources.num_queued_items = int(sqs_attributes["active_message_count"]) return resources @@ -113,9 +113,12 @@ async def delete_resources( ) sqs_result = True try: - await self.sqs_delegate.delete_queue(endpoint_id=endpoint_id) - except EndpointResourceInvalidRequestException as e: + await self.queue_delegate.delete_queue(endpoint_id=endpoint_id) + except EndpointResourceInfraException as e: logger.warning("Could not delete SQS resources", exc_info=e) sqs_result = False + if self.inference_autoscaling_metrics_gateway is not None: + await self.inference_autoscaling_metrics_gateway.delete_resources(endpoint_id) + return k8s_result and sqs_result diff --git a/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py new file mode 100644 index 00000000..76c77e64 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/resources/queue_endpoint_resource_delegate.py @@ -0,0 +1,46 @@ +from abc import ABC, abstractmethod +from typing import Any, Dict, NamedTuple, Optional, Sequence + +__all__: Sequence[str] = ( + "QueueInfo", + "QueueEndpointResourceDelegate", +) + + +class QueueInfo(NamedTuple): + queue_name: str + queue_url: Optional[str] + + +class QueueEndpointResourceDelegate(ABC): + """ + Base class for an interactor with SQS or ASB. This is used by the LiveEndpointResourceGateway. + """ + + @abstractmethod + async def create_queue_if_not_exists( + self, + endpoint_id: str, + endpoint_name: str, + endpoint_created_by: str, + endpoint_labels: Dict[str, Any], + ) -> QueueInfo: + """ + Creates a queue associated with the given endpoint_id. Other fields are set as tags on the queue. + """ + + @abstractmethod + async def delete_queue(self, endpoint_id: str) -> None: + """ + Deletes a queue associated with the given endpoint_id. This is a no-op if the queue does not exist. + """ + + @abstractmethod + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + """ + Get attributes of a queue. + """ + + @staticmethod + def endpoint_id_to_queue_name(endpoint_id: str) -> str: + return f"launch-endpoint-id-{endpoint_id}" diff --git a/server/llm_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py b/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py similarity index 80% rename from server/llm_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py rename to model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py index 87dac35a..748c3f69 100644 --- a/server/llm_engine_server/infra/gateways/resources/live_sqs_endpoint_resource_delegate.py +++ b/model-engine/model_engine_server/infra/gateways/resources/sqs_queue_endpoint_resource_delegate.py @@ -5,23 +5,25 @@ import botocore.exceptions from aioboto3 import Session as AioSession from aiobotocore.client import AioBaseClient -from llm_engine_server.common.config import hmi_config -from llm_engine_server.core.aws.roles import session -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, - SQSQueueInfo, +from model_engine_server.common.config import hmi_config +from model_engine_server.core.aws.roles import session +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, + QueueInfo, ) -from mypy_boto3_sqs.type_defs import GetQueueAttributesResultTypeDef -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) -__all__: Sequence[str] = ("LiveSQSEndpointResourceDelegate",) +__all__: Sequence[str] = ("SQSQueueEndpointResourceDelegate",) def _create_async_sqs_client(sqs_profile: Optional[str]) -> AioBaseClient: - return session(role=sqs_profile, session_type=AioSession).client("sqs", region_name="us-west-2") + return session(role=sqs_profile, session_type=AioSession).client( + "sqs", region_name=infra_config().default_region + ) def _get_queue_policy(queue_name: str) -> str: @@ -43,7 +45,7 @@ def _get_queue_tags( ) -class LiveSQSEndpointResourceDelegate(SQSEndpointResourceDelegate): +class SQSQueueEndpointResourceDelegate(QueueEndpointResourceDelegate): def __init__(self, sqs_profile: Optional[str]): self.sqs_profile = sqs_profile @@ -53,13 +55,13 @@ async def create_queue_if_not_exists( endpoint_name: str, endpoint_created_by: str, endpoint_labels: Dict[str, Any], - ) -> SQSQueueInfo: + ) -> QueueInfo: async with _create_async_sqs_client(sqs_profile=self.sqs_profile) as sqs_client: - queue_name = SQSEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) try: get_queue_url_response = await sqs_client.get_queue_url(QueueName=queue_name) - return SQSQueueInfo( + return QueueInfo( queue_name=queue_name, queue_url=get_queue_url_response["QueueUrl"], ) @@ -91,10 +93,10 @@ async def create_queue_if_not_exists( f"Creating SQS queue got non-200 response: {create_response}" ) - return SQSQueueInfo(queue_name, create_response["QueueUrl"]) + return QueueInfo(queue_name, create_response["QueueUrl"]) async def delete_queue(self, endpoint_id: str) -> None: - queue_name = SQSEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) async with _create_async_sqs_client(self.sqs_profile) as sqs_client: try: queue_url = (await sqs_client.get_queue_url(QueueName=queue_name))["QueueUrl"] @@ -119,8 +121,8 @@ async def delete_queue(self, endpoint_id: str) -> None: f"Deleting SQS queue got non-200 response: {delete_response}" ) - async def get_queue_attributes(self, endpoint_id: str) -> GetQueueAttributesResultTypeDef: - queue_name = SQSEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) + async def get_queue_attributes(self, endpoint_id: str) -> Dict[str, Any]: + queue_name = QueueEndpointResourceDelegate.endpoint_id_to_queue_name(endpoint_id) async with _create_async_sqs_client(self.sqs_profile) as sqs_client: try: queue_url = (await sqs_client.get_queue_url(QueueName=queue_name))["QueueUrl"] diff --git a/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml similarity index 65% rename from server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml rename to model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml index 41ffe75b..83e0fa0d 100644 --- a/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml +++ b/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml @@ -1,18 +1,19 @@ --- -# Source: llm-engine/templates/service_template_config_map.yaml +# Source: model-engine/templates/service_template_config_map.yaml # THIS FILE IS AUTOGENERATED USING `just autogen-templates`. PLEASE EDIT THE GOTEMPLATE FILE IN THE HELM CHART!!! apiVersion: v1 kind: ConfigMap metadata: - name: llm-engine-service-template-config + name: model-engine-service-template-config labels: team: infra - product: llm-engine - helm.sh/chart: llm-engine-0.1.0 - app.kubernetes.io/managed-by: Helm - app.kubernetes.io/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + app.kubernetes.io/version: 88f8003b2b52c772e8f34d264b3dfb95da1c1e9b + tags.datadoghq.com/version: 88f8003b2b52c772e8f34d264b3dfb95da1c1e9b tags.datadoghq.com/env: circleci + env: circleci + product: model-engine + helm.sh/chart: model-engine-0.1.3 + app.kubernetes.io/managed-by: Helm annotations: "helm.sh/hook": pre-install,pre-upgrade "helm.sh/hook-weight": "-2" @@ -30,10 +31,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -65,10 +66,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -99,53 +100,58 @@ data: values: - "True" topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 1800 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility - "VISIBILITY_24H" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.async.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --concurrency + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --num-workers - "${PER_WORKER}" + - --broker-type + - redis env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: CELERY_QUEUE value: "${QUEUE}" - name: CELERY_TASK_VISIBILITY @@ -165,7 +171,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -174,9 +180,9 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs - name: tritonserver - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton + image: nvidia/tritonserver:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -187,6 +193,8 @@ data: env: - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" ports: - containerPort: 8000 name: http @@ -215,7 +223,7 @@ data: ${TRITON_STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -234,6 +242,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -245,16 +254,12 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -270,7 +275,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -282,7 +287,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -299,10 +304,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -334,10 +339,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -368,53 +373,58 @@ data: values: - "True" topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 1800 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility - "VISIBILITY_24H" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.async.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --concurrency + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --num-workers - "${PER_WORKER}" + - --broker-type + - redis env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: CELERY_QUEUE value: "${QUEUE}" - name: CELERY_TASK_VISIBILITY @@ -434,7 +444,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -443,7 +453,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs - name: main securityContext: capabilities: @@ -459,6 +469,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -470,16 +481,12 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -495,7 +502,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -507,11 +514,11 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml - deployment-artifact-async-cpu.yaml: |- + deployment-triton-enhanced-runnable-image-sync-cpu.yaml: |- apiVersion: apps/v1 kind: Deployment metadata: @@ -524,20 +531,13 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} - annotations: - celery.scaleml.autoscaler/queue: ${QUEUE} - celery.scaleml.autoscaler/broker: ${BROKER_NAME} - celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" - celery.scaleml.autoscaler/perWorker: "${PER_WORKER}" - celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" - celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" spec: strategy: type: RollingUpdate @@ -559,14 +559,13 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} - sidecar.istio.io/inject: "false" # TODO: switch to scuttle version: v1 annotations: ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' @@ -595,81 +594,151 @@ data: topologyKey: kubernetes.io/hostname terminationGracePeriodSeconds: 600 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - - image: ${IMAGE} + - name: http-forwarder + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent - name: main - securityContext: - capabilities: - drop: - - all + command: + - /usr/bin/dumb-init + - -- + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" + - name: BASE_PATH + value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: CELERY_S3_BUCKET - value: "${CELERY_S3_BUCKET}" - - name: BROKER_TYPE - value: "${BROKER_TYPE}" - - name: SQS_PROFILE - value: "${SQS_PROFILE}" - - name: SQS_QUEUE_NAME - value: "${QUEUE}" - - name: SQS_QUEUE_URL - value: "${SQS_QUEUE_URL}" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" + - name: HTTP_HOST + value: "0.0.0.0" readinessProbe: - exec: - command: - - cat - - /tmp/readyz - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - # Not including --pool=solo means there's a worker process and a separate supervisor process - # meaning if the worker crashes (because of OOM or something) the supervisor process can mark the task as - # failed, which should get rid of infinite task retries - args: - - celery - - --app=llm_engine.inference.async_inference - - worker - - --loglevel=INFO - - --concurrency=1 - - --queues=${QUEUE} - - -O - - fair + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + + + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - name: user-config + mountPath: /workspace/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /workspace/endpoint_config + subPath: raw_data + - name: infra-service-config-volume + mountPath: /workspace/model-engine/model_engine_server/core/configs + ports: + - containerPort: ${FORWARDER_PORT} + name: http + - name: tritonserver + image: nvidia/tritonserver:${TRITON_COMMIT_TAG}-triton + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + - bash + - -c + - "$TRITON_COMMAND" + env: + - name: AWS_PROFILE + value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" + ports: + - containerPort: 8000 + name: http + - containerPort: 8001 + name: grpc + - containerPort: 8002 + name: metrics + readinessProbe: + httpGet: + # Need to have Triton support --http-address IPv6 :( + # https://github:com/triton-inference-server/server/issues/5305: + # path: /v2/health/ready + # port: 8000 + path: /readyz + port: 3000 + initialDelaySeconds: $TRITON_READINESS_INITIAL_DELAY + periodSeconds: 10 + resources: + requests: + cpu: ${TRITON_CPUS} + ${TRITON_MEMORY_DICT} + ${TRITON_STORAGE_DICT} + limits: + cpu: ${TRITON_CPUS} + ${TRITON_MEMORY_DICT} + ${TRITON_STORAGE_DICT} + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - mountPath: /dev/shm + name: dshm + - name: main + securityContext: + capabilities: + drop: + - all + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} + readinessProbe: + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -681,26 +750,28 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /opt/.aws/config subPath: config + - mountPath: /dev/shm + name: dshm + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - name: user-config - mountPath: ${BASE_PATH}/user_config + mountPath: /app/user_config subPath: raw_data - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config + mountPath: /app/endpoint_config subPath: raw_data - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 securityContext: fsGroup: 65534 volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -712,11 +783,11 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml - deployment-triton-enhanced-runnable-image-sync-cpu.yaml: |- + deployment-runnable-image-sync-cpu.yaml: |- apiVersion: apps/v1 kind: Deployment metadata: @@ -729,10 +800,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -757,10 +828,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -792,51 +863,54 @@ data: topologyKey: kubernetes.io/hostname terminationGracePeriodSeconds: 600 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --port - "${FORWARDER_PORT}" - - --concurrency - - "${PER_WORKER}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -845,9 +919,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -858,7 +933,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -867,54 +942,10 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http - - name: tritonserver - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton - imagePullPolicy: IfNotPresent - command: - - /usr/bin/dumb-init - - -- - - bash - - -c - - "$TRITON_COMMAND" - env: - - name: AWS_PROFILE - value: "${AWS_ROLE}" - ports: - - containerPort: 8000 - name: http - - containerPort: 8001 - name: grpc - - containerPort: 8002 - name: metrics - readinessProbe: - httpGet: - # Need to have Triton support --http-address IPv6 :( - # https://github:com/triton-inference-server/server/issues/5305: - # path: /v2/health/ready - # port: 8000 - path: /readyz - port: 3000 - initialDelaySeconds: $TRITON_READINESS_INITIAL_DELAY - periodSeconds: 10 - resources: - requests: - cpu: ${TRITON_CPUS} - ${TRITON_MEMORY_DICT} - ${TRITON_STORAGE_DICT} - limits: - cpu: ${TRITON_CPUS} - ${TRITON_MEMORY_DICT} - ${TRITON_STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - mountPath: /dev/shm - name: dshm - name: main securityContext: capabilities: @@ -930,6 +961,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -941,16 +973,12 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -966,7 +994,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -978,11 +1006,11 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml - deployment-runnable-image-sync-cpu.yaml: |- + deployment-runnable-image-streaming-cpu.yaml: |- apiVersion: apps/v1 kind: Deployment metadata: @@ -995,10 +1023,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1023,10 +1051,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -1058,51 +1086,60 @@ data: topologyKey: kubernetes.io/hostname terminationGracePeriodSeconds: 600 serviceAccount: default - nodeSelector: - node-lifecycle: normal priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - - --concurrency - - "${PER_WORKER}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.stream.predict_route=${STREAMING_PREDICT_ROUTE}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -1111,9 +1148,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -1124,7 +1162,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -1133,7 +1171,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http @@ -1152,6 +1190,7 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: cpu: ${CPUS} @@ -1163,16 +1202,12 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -1188,7 +1223,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -1200,11 +1235,11 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml - deployment-artifact-sync-cpu.yaml: |- + deployment-triton-enhanced-runnable-image-async-gpu.yaml: |- apiVersion: apps/v1 kind: Deployment metadata: @@ -1217,13 +1252,20 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} + annotations: + celery.scaleml.autoscaler/queue: ${QUEUE} + celery.scaleml.autoscaler/broker: ${BROKER_NAME} + celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" + celery.scaleml.autoscaler/perWorker: "${PER_WORKER}" + celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" + celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" spec: strategy: type: RollingUpdate @@ -1245,13 +1287,14 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} + sidecar.istio.io/inject: "false" # TODO: switch to scuttle version: v1 annotations: ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' @@ -1278,472 +1321,64 @@ data: values: - "True" topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 1800 serviceAccount: default nodeSelector: - node-lifecycle: normal + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" priorityClassName: ${PRIORITY} containers: - - image: ${IMAGE} - imagePullPolicy: IfNotPresent - name: main - securityContext: - capabilities: - drop: - - all - env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: PORT - value: "${ARTIFACT_LIKE_CONTAINER_PORT}" - readinessProbe: - httpGet: - path: /readyz - port: ${ARTIFACT_LIKE_CONTAINER_PORT} - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - args: - - python - - -m - - llm_engine.inference.sync_inference.start_fastapi_server - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: ${BASE_PATH}/user_config - subPath: raw_data - - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config - subPath: raw_data - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - name: default-config - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - - name: infra-service-config-volume - configMap: - name: llm-engine-service-config - items: - - key: infra_service_config - path: config.yaml - deployment-runnable-image-streaming-cpu.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 - serviceAccount: default - nodeSelector: - node-lifecycle: normal - priorityClassName: ${PRIORITY} - containers: - - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} - imagePullPolicy: IfNotPresent - command: - - /usr/bin/dumb-init - - -- - - ddtrace-run - - python - - -m - - llm_engine.inference.forwarding.http_forwarder - - --config - - /workspace/llm_engine/llm_engine/inference/configs/service--http_forwarder.yaml - - --port - - "${FORWARDER_PORT}" - - --num-workers - - "${PER_WORKER}" - - --set - - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - - --set - - "forwarder.stream.predict_route=${STREAMING_PREDICT_ROUTE}" - - --set - - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --set - - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" - env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: BASE_PATH - value: "/workspace" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: HTTP_HOST - value: "0.0.0.0" - readinessProbe: - httpGet: - path: /readyz - port: ${FORWARDER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} - periodSeconds: 5 - resources: - requests: - cpu: 0.1 - memory: "100M" - ephemeral-storage: "100M" - limits: - cpu: ${FORWARDER_CPUS_LIMIT} - memory: ${FORWARDER_MEMORY_LIMIT} - ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} - - - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: user-config - mountPath: /workspace/user_config - subPath: raw_data - - name: endpoint-config - mountPath: /workspace/endpoint_config - subPath: raw_data - - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs - ports: - - containerPort: ${FORWARDER_PORT} - name: http - - name: main - securityContext: - capabilities: - drop: - - all - image: ${IMAGE} - imagePullPolicy: IfNotPresent - command: ${COMMAND} - env: ${MAIN_ENV} - readinessProbe: - httpGet: - path: ${HEALTHCHECK_ROUTE} - port: ${USER_CONTAINER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} - periodSeconds: 5 - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - mountPath: /dev/shm - name: dshm - - name: infra-service-config-volume - mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: /app/user_config - subPath: raw_data - - name: endpoint-config - mountPath: /app/endpoint_config - subPath: raw_data - ports: - - containerPort: ${USER_CONTAINER_PORT} - name: http - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - name: default-config - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - - name: infra-service-config-volume - configMap: - name: llm-engine-service-config - items: - - key: infra_service_config - path: config.yaml - deployment-triton-enhanced-runnable-image-async-gpu.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - annotations: - celery.scaleml.autoscaler/queue: ${QUEUE} - celery.scaleml.autoscaler/broker: ${BROKER_NAME} - celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" - celery.scaleml.autoscaler/perWorker: "${PER_WORKER}" - celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" - celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - sidecar.istio.io/inject: "false" # TODO: switch to scuttle - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 - serviceAccount: default - nodeSelector: - node-lifecycle: normal - k8s.amazonaws.com/accelerator: ${GPU_TYPE} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - priorityClassName: ${PRIORITY} - containers: - - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + - name: celery-forwarder + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --queue - "${QUEUE}" - --task-visibility - "VISIBILITY_24H" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.async.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --concurrency + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --num-workers - "${PER_WORKER}" + - --broker-type + - redis env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: CELERY_QUEUE value: "${QUEUE}" - name: CELERY_TASK_VISIBILITY @@ -1763,7 +1398,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -1772,9 +1407,9 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs - name: tritonserver - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton + image: nvidia/tritonserver:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -1785,6 +1420,8 @@ data: env: - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" ports: - containerPort: 8000 name: http @@ -1813,241 +1450,10 @@ data: ${TRITON_STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config - subPath: config - - mountPath: /dev/shm - name: dshm - - name: main - securityContext: - capabilities: - drop: - - all - image: ${IMAGE} - imagePullPolicy: IfNotPresent - command: ${COMMAND} - env: ${MAIN_ENV} - readinessProbe: - httpGet: - path: ${HEALTHCHECK_ROUTE} - port: ${USER_CONTAINER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} - periodSeconds: 5 - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - nvidia.com/gpu: ${GPUS} - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm - name: dshm - - name: infra-service-config-volume - mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: /app/user_config - subPath: raw_data - - name: endpoint-config - mountPath: /app/endpoint_config - subPath: raw_data - ports: - - containerPort: ${USER_CONTAINER_PORT} - name: http - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - name: default-config - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - - name: infra-service-config-volume - configMap: - name: llm-engine-service-config - items: - - key: infra_service_config - path: config.yaml - deployment-runnable-image-async-gpu.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - annotations: - celery.scaleml.autoscaler/queue: ${QUEUE} - celery.scaleml.autoscaler/broker: ${BROKER_NAME} - celery.scaleml.autoscaler/taskVisibility: "VISIBILITY_24H" - celery.scaleml.autoscaler/perWorker: "${PER_WORKER}" - celery.scaleml.autoscaler/minWorkers: "${MIN_WORKERS}" - celery.scaleml.autoscaler/maxWorkers: "${MAX_WORKERS}" - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - sidecar.istio.io/inject: "false" # TODO: switch to scuttle - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 - serviceAccount: default - nodeSelector: - node-lifecycle: normal - k8s.amazonaws.com/accelerator: ${GPU_TYPE} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - priorityClassName: ${PRIORITY} - containers: - - name: celery-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} - imagePullPolicy: IfNotPresent - command: - - /usr/bin/dumb-init - - -- - - ddtrace-run - - run-service - - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --queue - - "${QUEUE}" - - --task-visibility - - "VISIBILITY_24H" - - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" - - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" - - --concurrency - - "${PER_WORKER}" - env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: BASE_PATH - value: "/workspace" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: CELERY_QUEUE - value: "${QUEUE}" - - name: CELERY_TASK_VISIBILITY - value: "VISIBILITY_24H" - - name: S3_BUCKET - value: "${CELERY_S3_BUCKET}" - resources: - requests: - cpu: 0.1 - memory: "100M" - ephemeral-storage: "100M" - limits: - cpu: ${FORWARDER_CPUS_LIMIT} - memory: ${FORWARDER_MEMORY_LIMIT} - ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} - - - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: user-config - mountPath: /workspace/user_config - subPath: raw_data - - name: endpoint-config - mountPath: /workspace/endpoint_config - subPath: raw_data - - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + name: dshm - name: main securityContext: capabilities: @@ -2063,8 +1469,10 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -2075,16 +1483,12 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -2100,7 +1504,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -2112,11 +1516,11 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml - deployment-artifact-async-gpu.yaml: |- + deployment-runnable-image-async-gpu.yaml: |- apiVersion: apps/v1 kind: Deployment metadata: @@ -2129,10 +1533,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2164,10 +1568,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2198,10 +1602,9 @@ data: values: - "True" topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 + terminationGracePeriodSeconds: 1800 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -2209,79 +1612,102 @@ data: effect: "NoSchedule" priorityClassName: ${PRIORITY} containers: - - image: ${IMAGE} + - name: celery-forwarder + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent - name: main - securityContext: - capabilities: - drop: - - all + command: + - /usr/bin/dumb-init + - -- + - python + - -m + - model_engine_server.inference.forwarding.celery_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} + - --queue + - "${QUEUE}" + - --task-visibility + - "VISIBILITY_24H" + - --set + - "forwarder.async.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.async.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --num-workers + - "${PER_WORKER}" + - --broker-type + - redis env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" + - name: BASE_PATH + value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: CELERY_S3_BUCKET - value: "${CELERY_S3_BUCKET}" - - name: BROKER_TYPE - value: "${BROKER_TYPE}" - - name: SQS_PROFILE - value: "${SQS_PROFILE}" - - name: SQS_QUEUE_NAME + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" + - name: CELERY_QUEUE value: "${QUEUE}" - - name: SQS_QUEUE_URL - value: "${SQS_QUEUE_URL}" + - name: CELERY_TASK_VISIBILITY + value: "VISIBILITY_24H" + - name: S3_BUCKET + value: "${CELERY_S3_BUCKET}" + resources: + requests: + cpu: 0.1 + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + + + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - name: user-config + mountPath: /workspace/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /workspace/endpoint_config + subPath: raw_data + - name: infra-service-config-volume + mountPath: /workspace/model-engine/model_engine_server/core/configs + - name: main + securityContext: + capabilities: + drop: + - all + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} readinessProbe: - exec: - command: - - cat - - /tmp/readyz - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - # Not including --pool=solo means there's a worker process and a separate supervisor process - # meaning if the worker crashes (because of OOM or something) the supervisor process can mark the task as - # failed, which should get rid of infinite task retries - args: - - celery - - --app=llm_engine.inference.async_inference - - worker - - --loglevel=INFO - - --concurrency=1 - - --queues=${QUEUE} - - -O - - fair + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -2292,26 +1718,28 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /opt/.aws/config subPath: config + - mountPath: /dev/shm + name: dshm + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - name: user-config - mountPath: ${BASE_PATH}/user_config + mountPath: /app/user_config subPath: raw_data - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config + mountPath: /app/endpoint_config subPath: raw_data - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 securityContext: fsGroup: 65534 volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -2323,7 +1751,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -2340,10 +1768,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2368,10 +1796,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2404,7 +1832,6 @@ data: terminationGracePeriodSeconds: 600 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -2413,46 +1840,51 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --port - "${FORWARDER_PORT}" - - --concurrency - - "${PER_WORKER}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -2461,9 +1893,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -2474,7 +1907,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -2483,12 +1916,12 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http - name: tritonserver - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/std-ml-srv:${TRITON_COMMIT_TAG}-triton + image: nvidia/tritonserver:${TRITON_COMMIT_TAG}-triton imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init @@ -2499,6 +1932,8 @@ data: env: - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" ports: - containerPort: 8000 name: http @@ -2527,7 +1962,7 @@ data: ${TRITON_STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm @@ -2546,8 +1981,10 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -2558,16 +1995,12 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -2583,7 +2016,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -2595,7 +2028,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -2612,10 +2045,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2640,10 +2073,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -2676,7 +2109,6 @@ data: terminationGracePeriodSeconds: 600 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -2685,46 +2117,51 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - - run-service + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - - --http - - production_threads + - /workspace/model-engine/model_engine_server/inference/configs/${FORWARDER_CONFIG_FILE_NAME} - --port - "${FORWARDER_PORT}" - - --concurrency - - "${PER_WORKER}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set - - "forwarder.model.args.predict_route=${PREDICT_ROUTE}" + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - - "forwarder.model.args.healthcheck_route=${HEALTHCHECK_ROUTE}" + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -2733,9 +2170,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -2746,228 +2184,38 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config subPath: raw_data - name: endpoint-config - mountPath: /workspace/endpoint_config - subPath: raw_data - - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs - ports: - - containerPort: ${FORWARDER_PORT} - name: http - - name: main - securityContext: - capabilities: - drop: - - all - image: ${IMAGE} - imagePullPolicy: IfNotPresent - command: ${COMMAND} - env: ${MAIN_ENV} - readinessProbe: - httpGet: - path: ${HEALTHCHECK_ROUTE} - port: ${USER_CONTAINER_PORT} - initialDelaySeconds: ${READINESS_INITIAL_DELAY} - periodSeconds: 5 - resources: - requests: - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - limits: - nvidia.com/gpu: ${GPUS} - cpu: ${CPUS} - memory: ${MEMORY} - ${STORAGE_DICT} - volumeMounts: - - name: config-volume - mountPath: /root/.aws/config - subPath: config - - mountPath: /dev/shm - name: dshm - - name: infra-service-config-volume - mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - - name: user-config - mountPath: /app/user_config - subPath: raw_data - - name: endpoint-config - mountPath: /app/endpoint_config - subPath: raw_data - ports: - - containerPort: ${USER_CONTAINER_PORT} - name: http - # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 - securityContext: - fsGroup: 65534 - volumes: - - name: config-volume - configMap: - name: default-config - - name: user-config - configMap: - name: ${RESOURCE_NAME} - - name: endpoint-config - configMap: - name: ${RESOURCE_NAME}-endpoint-config - - name: dshm - emptyDir: - medium: Memory - - name: infra-service-config-volume - configMap: - name: llm-engine-service-config - items: - - key: infra_service_config - path: config.yaml - deployment-artifact-sync-gpu.yaml: |- - apiVersion: apps/v1 - kind: Deployment - metadata: - name: ${RESOURCE_NAME} - namespace: ${NAMESPACE} - labels: - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - spec: - strategy: - type: RollingUpdate - rollingUpdate: - maxSurge: 1 - maxUnavailable: 0 - replicas: ${MIN_WORKERS} - selector: - matchLabels: - app: ${RESOURCE_NAME} - version: v1 - template: - metadata: - labels: - app: ${RESOURCE_NAME} - user_id: ${OWNER} - team: ${TEAM} - product: ${PRODUCT} - created_by: ${CREATED_BY} - owner: ${OWNER} - env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" - tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - tags.datadoghq.com/service: ${ENDPOINT_NAME} - endpoint_id: ${ENDPOINT_ID} - endpoint_name: ${ENDPOINT_NAME} - version: v1 - annotations: - ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' - kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" - spec: - affinity: - podAffinity: - preferredDuringSchedulingIgnoredDuringExecution: - - weight: 1 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: app - operator: In - values: - - ${RESOURCE_NAME} - topologyKey: kubernetes.io/hostname - - weight: 100 - podAffinityTerm: - labelSelector: - matchExpressions: - - key: ${IMAGE_HASH} - operator: In - values: - - "True" - topologyKey: kubernetes.io/hostname - terminationGracePeriodSeconds: 600 - serviceAccount: default - nodeSelector: - node-lifecycle: normal - k8s.amazonaws.com/accelerator: ${GPU_TYPE} - tolerations: - - key: "nvidia.com/gpu" - operator: "Exists" - effect: "NoSchedule" - priorityClassName: ${PRIORITY} - containers: - - image: ${IMAGE} - imagePullPolicy: IfNotPresent - name: main + mountPath: /workspace/endpoint_config + subPath: raw_data + - name: infra-service-config-volume + mountPath: /workspace/model-engine/model_engine_server/core/configs + ports: + - containerPort: ${FORWARDER_PORT} + name: http + - name: main securityContext: capabilities: drop: - all - env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" - - name: DD_SERVICE - value: "${ENDPOINT_NAME}" - - name: DD_ENV - value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - - name: DD_AGENT_HOST - valueFrom: - fieldRef: - fieldPath: status.hostIP - - name: OMP_NUM_THREADS - value: "1" - - name: BASE_PATH - value: "${BASE_PATH}" - - name: BUNDLE_URL - value: "${BUNDLE_URL}" - - name: LOAD_PREDICT_FN_MODULE_PATH - value: "${LOAD_PREDICT_FN_MODULE_PATH}" - - name: LOAD_MODEL_FN_MODULE_PATH - value: "${LOAD_MODEL_FN_MODULE_PATH}" - - name: AWS_PROFILE - value: "${AWS_ROLE}" - - name: RESULTS_S3_BUCKET - value: "${RESULTS_S3_BUCKET}" - - name: CHILD_FN_INFO - value: "${CHILD_FN_INFO}" - - name: PREWARM - value: "${PREWARM}" - - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" - - name: PORT - value: "${ARTIFACT_LIKE_CONTAINER_PORT}" + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} readinessProbe: httpGet: - path: /readyz - port: ${ARTIFACT_LIKE_CONTAINER_PORT} - initialDelaySeconds: 2 - periodSeconds: 2 - failureThreshold: 100 - command: [ "dumb-init", "--", "ddtrace-run" ] - args: - - python - - -m - - llm_engine.inference.sync_inference.start_fastapi_server + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -2978,26 +2226,28 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config - subPath: config - - name: config-volume - mountPath: /home/llmengine/.aws/config + mountPath: /opt/.aws/config subPath: config + - mountPath: /dev/shm + name: dshm + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - name: user-config - mountPath: ${BASE_PATH}/user_config + mountPath: /app/user_config subPath: raw_data - name: endpoint-config - mountPath: ${BASE_PATH}/endpoint_config + mountPath: /app/endpoint_config subPath: raw_data - - name: infra-service-config-volume - mountPath: ${BASE_PATH}/ml_infra_core/llm_engine.core/llm_engine.core/configs + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http # Workaround for https://github.com/kubernetes-sigs/external-dns/pull/1185 securityContext: fsGroup: 65534 volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -3009,7 +2259,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -3026,10 +2276,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3054,10 +2304,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3090,7 +2340,6 @@ data: terminationGracePeriodSeconds: 600 serviceAccount: default nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -3099,21 +2348,20 @@ data: priorityClassName: ${PRIORITY} containers: - name: http-forwarder - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:${FORWARDER_IMAGE_TAG} + image: model-engine:${GIT_TAG} imagePullPolicy: IfNotPresent command: - /usr/bin/dumb-init - -- - - ddtrace-run - python - -m - - llm_engine.inference.forwarding.http_forwarder + - model_engine_server.inference.forwarding.http_forwarder - --config - - /workspace/llm_engine/llm_engine/inference/configs/service--http_forwarder.yaml + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml - --port - "${FORWARDER_PORT}" - --num-workers - - "${PER_WORKER}" + - "${FORWARDER_WORKER_COUNT}" - --set - "forwarder.sync.predict_route=${PREDICT_ROUTE}" - --set @@ -3122,27 +2370,35 @@ data: - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" - --set - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.sync.extra_routes=${FORWARDER_EXTRA_ROUTES}" + - --set + - "forwarder.stream.extra_routes=${FORWARDER_EXTRA_ROUTES}" env: - - name: DATADOG_TRACE_ENABLED - value: "${DATADOG_TRACE_ENABLED}" + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_SERVICE value: "${ENDPOINT_NAME}" - name: DD_ENV value: circleci - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + value: "${GIT_TAG}" - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - name: AWS_PROFILE value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: RESULTS_S3_BUCKET value: "${RESULTS_S3_BUCKET}" - name: BASE_PATH value: "/workspace" - name: ML_INFRA_SERVICES_CONFIG_PATH - value: "/workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml" + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" - name: HTTP_HOST value: "0.0.0.0" readinessProbe: @@ -3151,9 +2407,10 @@ data: port: ${FORWARDER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: - cpu: 0.1 + cpu: ${FORWARDER_CPUS_LIMIT} memory: "100M" ephemeral-storage: "100M" limits: @@ -3164,7 +2421,7 @@ data: volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: user-config mountPath: /workspace/user_config @@ -3173,7 +2430,7 @@ data: mountPath: /workspace/endpoint_config subPath: raw_data - name: infra-service-config-volume - mountPath: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs + mountPath: /workspace/model-engine/model_engine_server/core/configs ports: - containerPort: ${FORWARDER_PORT} name: http @@ -3192,8 +2449,10 @@ data: port: ${USER_CONTAINER_PORT} initialDelaySeconds: ${READINESS_INITIAL_DELAY} periodSeconds: 5 + timeoutSeconds: 5 resources: requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -3204,16 +2463,12 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - mountPath: /dev/shm name: dshm - name: infra-service-config-volume mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} - # LIRA: For compatibility with runnable image converted from artifactlike bundle - - name: config-volume - mountPath: /home/llmengine/.aws/config - subPath: config - name: user-config mountPath: /app/user_config subPath: raw_data @@ -3229,7 +2484,7 @@ data: volumes: - name: config-volume configMap: - name: default-config + name: default-config - name: user-config configMap: name: ${RESOURCE_NAME} @@ -3241,7 +2496,7 @@ data: medium: Memory - name: infra-service-config-volume configMap: - name: llm-engine-service-config + name: model-engine-service-config items: - key: infra_service_config path: config.yaml @@ -3258,10 +2513,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3280,10 +2535,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3302,10 +2557,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3324,6 +2579,382 @@ data: target: type: Value averageValue: ${CONCURRENCY} + keda-scaled-object.yaml: |- + apiVersion: keda.sh/v1alpha1 + kind: ScaledObject + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + scaleTargetRef: + name: ${RESOURCE_NAME} + pollingInterval: 5 + cooldownPeriod: 300 + minReplicaCount: ${MIN_WORKERS} + maxReplicaCount: ${MAX_WORKERS} + fallback: + failureThreshold: 3 + replicas: ${MIN_WORKERS} + triggers: + - type: redis + metadata: + address: ${REDIS_HOST_PORT} # Format must be host:port + passwordFromEnv: "" + listName: "launch-endpoint-autoscaling:${ENDPOINT_ID}" + listLength: "100" # something absurdly high so we don't scale past 1 pod + activationListLength: "0" + enableTLS: "false" + unsafeSsl: "false" + databaseIndex: "${REDIS_DB_INDEX}" + - type: prometheus + metadata: + threshold: "${CONCURRENCY}" + metricName: request_concurrency_average + query: sum(rate(istio_request_duration_milliseconds_sum{destination_workload="${RESOURCE_NAME}"}[2m])) / 1000 + serverAddress: ${PROMETHEUS_SERVER_ADDRESS} + leader-worker-set-streaming-gpu.yaml: |- + apiVersion: leaderworkerset.x-k8s.io/v1 + kind: LeaderWorkerSet + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + replicas: ${MIN_WORKERS} + leaderWorkerTemplate: + size: ${LWS_SIZE} + restartPolicy: RecreateGroupOnPodRestart # TODO un-hardcode? if necessary + leaderTemplate: + metadata: + labels: + app: ${RESOURCE_NAME} + role: leader + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + sidecar.istio.io/inject: "false" # Never inject istio, it screws up networking + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + podAffinity: + preferredDuringSchedulingIgnoredDuringExecution: + - weight: 1 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: app + operator: In + values: + - ${RESOURCE_NAME} + topologyKey: kubernetes.io/hostname + - weight: 100 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: ${IMAGE_HASH} + operator: In + values: + - "True" + topologyKey: kubernetes.io/hostname + terminationGracePeriodSeconds: 600 + serviceAccount: default + nodeSelector: + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + priorityClassName: ${PRIORITY} + containers: + - name: http-forwarder + image: model-engine:${GIT_TAG} + imagePullPolicy: IfNotPresent + command: + - /usr/bin/dumb-init + - -- + - python + - -m + - model_engine_server.inference.forwarding.http_forwarder + - --config + - /workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml + - --port + - "${FORWARDER_PORT}" + - --num-workers + - "${FORWARDER_WORKER_COUNT}" + - --set + - "forwarder.sync.predict_route=${PREDICT_ROUTE}" + - --set + - "forwarder.stream.predict_route=${STREAMING_PREDICT_ROUTE}" + - --set + - "forwarder.sync.healthcheck_route=${HEALTHCHECK_ROUTE}" + - --set + - "forwarder.stream.healthcheck_route=${HEALTHCHECK_ROUTE}" + env: + - name: DD_TRACE_ENABLED + value: "${DD_TRACE_ENABLED}" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" + - name: DD_SERVICE + value: "${ENDPOINT_NAME}" + - name: DD_ENV + value: circleci + - name: DD_VERSION + value: "${GIT_TAG}" + - name: DD_AGENT_HOST + valueFrom: + fieldRef: + fieldPath: status.hostIP + - name: AWS_PROFILE + value: "${AWS_ROLE}" + - name: AWS_CONFIG_FILE + value: /opt/.aws/config + - name: RESULTS_S3_BUCKET + value: "${RESULTS_S3_BUCKET}" + - name: BASE_PATH + value: "/workspace" + - name: ML_INFRA_SERVICES_CONFIG_PATH + value: "/workspace/model-engine/model_engine_server/core/configs/config.yaml" + - name: HTTP_HOST + value: "0.0.0.0" + readinessProbe: + httpGet: + path: /readyz + port: ${FORWARDER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: "100M" + ephemeral-storage: "100M" + limits: + cpu: ${FORWARDER_CPUS_LIMIT} + memory: ${FORWARDER_MEMORY_LIMIT} + ephemeral-storage: ${FORWARDER_STORAGE_LIMIT} + + + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - name: user-config + mountPath: /workspace/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /workspace/endpoint_config + subPath: raw_data + - name: infra-service-config-volume + mountPath: /workspace/model-engine/model_engine_server/core/configs + ports: + - containerPort: ${FORWARDER_PORT} + name: http + - name: lws-leader + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${COMMAND} + env: ${MAIN_ENV} + readinessProbe: + httpGet: + path: ${HEALTHCHECK_ROUTE} + port: ${USER_CONTAINER_PORT} + initialDelaySeconds: ${READINESS_INITIAL_DELAY} + periodSeconds: 5 + timeoutSeconds: 5 + resources: + requests: + nvidia.com/gpu: ${GPUS} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + nvidia.com/gpu: ${GPUS} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - mountPath: /dev/shm + name: dshm + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + volumes: + - name: config-volume + configMap: + name: default-config + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + - name: infra-service-config-volume + configMap: + name: model-engine-service-config + items: + - key: infra_service_config + path: config.yaml + workerTemplate: + metadata: + labels: + app: ${RESOURCE_NAME} + role: worker + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + sidecar.istio.io/inject: "false" # Never inject istio for LWS, it screws up networking + version: v1 + annotations: + ad.datadoghq.com/main.logs: '[{"service": "${ENDPOINT_NAME}", "source": "python"}]' + kubernetes.io/change-cause: "${CHANGE_CAUSE_MESSAGE}" + spec: + affinity: + podAffinity: + preferredDuringSchedulingIgnoredDuringExecution: + - weight: 1 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: app + operator: In + values: + - ${RESOURCE_NAME} + topologyKey: kubernetes.io/hostname + - weight: 100 + podAffinityTerm: + labelSelector: + matchExpressions: + - key: ${IMAGE_HASH} + operator: In + values: + - "True" + topologyKey: kubernetes.io/hostname + terminationGracePeriodSeconds: 600 + serviceAccount: default + nodeSelector: + k8s.amazonaws.com/accelerator: ${GPU_TYPE} + tolerations: + - key: "nvidia.com/gpu" + operator: "Exists" + effect: "NoSchedule" + priorityClassName: ${PRIORITY} + containers: + - name: lws-worker + image: ${IMAGE} + imagePullPolicy: IfNotPresent + command: ${WORKER_COMMAND} + env: ${WORKER_ENV} + resources: + requests: + nvidia.com/gpu: ${GPUS} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + limits: + nvidia.com/gpu: ${GPUS} + cpu: ${CPUS} + memory: ${MEMORY} + ${STORAGE_DICT} + volumeMounts: + - name: config-volume + mountPath: /opt/.aws/config + subPath: config + - mountPath: /dev/shm + name: dshm + - name: infra-service-config-volume + mountPath: ${INFRA_SERVICE_CONFIG_VOLUME_MOUNT_PATH} + - name: user-config + mountPath: /app/user_config + subPath: raw_data + - name: endpoint-config + mountPath: /app/endpoint_config + subPath: raw_data + ports: + - containerPort: ${USER_CONTAINER_PORT} + name: http + volumes: + - name: config-volume + configMap: + name: default-config + - name: user-config + configMap: + name: ${RESOURCE_NAME} + - name: endpoint-config + configMap: + name: ${RESOURCE_NAME}-endpoint-config + - name: dshm + emptyDir: + medium: Memory + - name: infra-service-config-volume + configMap: + name: model-engine-service-config + items: + - key: infra_service_config + path: config.yaml # mode # device service.yaml: |- apiVersion: v1 kind: Service @@ -3337,10 +2968,40 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + type: ${SERVICE_TYPE} + selector: + app: ${RESOURCE_NAME} + ports: + - port: 80 + targetPort: ${SERVICE_TARGET_PORT} + protocol: TCP + name: http + ${NODE_PORT_DICT} + lws-service.yaml: |- + apiVersion: v1 + kind: Service + metadata: + name: ${SERVICE_NAME_OVERRIDE} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3348,12 +3009,98 @@ data: type: ${SERVICE_TYPE} selector: app: ${RESOURCE_NAME} + role: leader ports: - port: 80 targetPort: ${SERVICE_TARGET_PORT} protocol: TCP name: http ${NODE_PORT_DICT} + virtual-service.yaml: |- + apiVersion: networking.istio.io/v1alpha3 + kind: VirtualService + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + hosts: + - ${RESOURCE_NAME}.${DNS_HOST_DOMAIN} + gateways: + - default/internal-gateway + http: + - route: + - destination: + host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" + port: + number: 80 + destination-rule.yaml: |- + apiVersion: networking.istio.io/v1beta1 + kind: DestinationRule + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + host: "${RESOURCE_NAME}.${NAMESPACE}.svc.cluster.local" + trafficPolicy: + loadBalancer: + simple: LEAST_REQUEST + lws-service-entry.yaml: |- + apiVersion: networking.istio.io/v1beta1 + kind: ServiceEntry + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + hosts: + - "${SERVICE_NAME_OVERRIDE}.${NAMESPACE}.svc.cluster.local" + location: MESH_EXTERNAL + ports: + - number: 80 + name: http + protocol: HTTP + resolution: NONE vertical-pod-autoscaler.yaml: |- apiVersion: "autoscaling.k8s.io/v1" kind: VerticalPodAutoscaler @@ -3366,10 +3113,10 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + tags.datadoghq.com/version: ${GIT_TAG} tags.datadoghq.com/service: ${ENDPOINT_NAME} endpoint_id: ${ENDPOINT_ID} endpoint_name: ${ENDPOINT_NAME} @@ -3392,6 +3139,31 @@ data: cpu: ${CPUS} memory: ${MEMORY} controlledResources: ["cpu", "memory"] + pod-disruption-budget.yaml: |- + apiVersion: policy/v1 + kind: PodDisruptionBudget + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + env: circleci + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/env: circleci + tags.datadoghq.com/version: ${GIT_TAG} + tags.datadoghq.com/service: ${ENDPOINT_NAME} + endpoint_id: ${ENDPOINT_ID} + endpoint_name: ${ENDPOINT_NAME} + spec: + maxUnavailable: 50% + selector: + matchLabels: + app: ${RESOURCE_NAME} batch-job-orchestration-job.yaml: |- apiVersion: batch/v1 kind: Job @@ -3404,12 +3176,15 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} spec: backoffLimit: 0 activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} @@ -3423,71 +3198,85 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} sidecar.istio.io/inject: "false" version: v1 annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "llm_engine_job_id:${JOB_ID}"]}]' + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "launch_job_id:${JOB_ID}"]}]' cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never - nodeSelector: - node-lifecycle: normal - serviceAccountName: llm-engine + serviceAccountName: model-engine volumes: - name: config-volume configMap: name: default-config containers: - name: main - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + image: model-engine:${GIT_TAG} env: - name: DD_SERVICE value: ${RESOURCE_NAME} - - name: DATADOG_TRACE_ENABLED + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" + - name: DD_TRACE_ENABLED value: "true" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_ENV value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: GIT_TAG - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: SERVICE_IDENTIFIER + - name: GATEWAY_URL + value: http://model-engine.default:80 - name: AWS_PROFILE value: default + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: ECR_READ_AWS_PROFILE value: default + - name: DB_SECRET_AWS_PROFILE + value: default + - name: S3_WRITE_AWS_PROFILE + value: default - name: ML_INFRA_DATABASE_URL valueFrom: secretKeyRef: key: database_url - name: ml-infra-pg + name: model-engine-postgres-credentials - name: DEPLOY_SERVICE_CONFIG_PATH - value: /workspace/llm_engine/service_configs/service_config.yaml + value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH - value: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml + value: /workspace/model-engine/model_engine_server/core/configs/config.yaml - name: CELERY_ELASTICACHE_ENABLED value: "true" - - name: LLM_ENGINE_SERVICE_TEMPLATE_FOLDER - value: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates - imagePullPolicy: Always + - name: LAUNCH_SERVICE_TEMPLATE_FOLDER + value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: CIRCLECI + value: "true" + - name: DD_VERSION + value: ${GIT_TAG} + - name: GIT_TAG + value: ${GIT_TAG} + imagePullPolicy: IfNotPresent command: - dumb-init - -- - - ddtrace-run args: - python - -m - - server.llm_engine_server.entrypoints.start_batch_job_orchestration + - model_engine_server.entrypoints.start_batch_job_orchestration - --job-id - ${JOB_ID} - --owner @@ -3503,12 +3292,14 @@ data: requests: cpu: 1 memory: 8Gi + ephemeral-storage: 10Gi limits: cpu: 4 memory: 32Gi + ephemeral-storage: 30Gi volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config docker-image-batch-job-cpu.yaml: |- apiVersion: batch/v1 @@ -3522,16 +3313,22 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} spec: backoffLimit: 0 activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} + completions: ${BATCH_JOB_NUM_WORKERS} + parallelism: ${BATCH_JOB_NUM_WORKERS} + completionMode: "Indexed" template: metadata: labels: @@ -3541,20 +3338,22 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} sidecar.istio.io/inject: "false" version: v1 annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "llm_engine_job_id:${JOB_ID}"]}]' + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "launch_job_id:${JOB_ID}"]}]' + cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never - nodeSelector: - node-lifecycle: normal serviceAccountName: default volumes: - name: config-volume @@ -3571,37 +3370,51 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} - - name: DATADOG_TRACE_ENABLED + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" + - name: DD_TRACE_ENABLED value: "true" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_ENV value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: GIT_TAG - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: SERVICE_IDENTIFIER + - name: GATEWAY_URL + value: http://model-engine.default:80 - name: AWS_PROFILE value: default + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: ECR_READ_AWS_PROFILE value: default + - name: DB_SECRET_AWS_PROFILE + value: default + - name: S3_WRITE_AWS_PROFILE + value: default - name: ML_INFRA_DATABASE_URL valueFrom: secretKeyRef: key: database_url - name: ml-infra-pg + name: model-engine-postgres-credentials - name: DEPLOY_SERVICE_CONFIG_PATH - value: /workspace/llm_engine/service_configs/service_config.yaml + value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH - value: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml + value: /workspace/model-engine/model_engine_server/core/configs/config.yaml - name: CELERY_ELASTICACHE_ENABLED value: "true" - - name: LLM_ENGINE_SERVICE_TEMPLATE_FOLDER - value: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates - imagePullPolicy: Always + - name: LAUNCH_SERVICE_TEMPLATE_FOLDER + value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: CIRCLECI + value: "true" + - name: DD_VERSION + value: ${GIT_TAG} + - name: GIT_TAG + value: ${GIT_TAG} + imagePullPolicy: IfNotPresent command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. @@ -3615,7 +3428,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} @@ -3623,11 +3436,14 @@ data: name: dshm initContainers: - name: input-downloader - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + image: model-engine:${GIT_TAG} + env: + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" command: - python - -m - - server.llm_engine_server.entrypoints.start_docker_image_batch_job_init_container + - model_engine_server.entrypoints.start_docker_image_batch_job_init_container - ${INPUT_LOCATION} - --remote-file - ${S3_FILE} @@ -3644,7 +3460,7 @@ data: memory: 1Gi volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} @@ -3660,16 +3476,22 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} spec: backoffLimit: 0 activeDeadlineSeconds: ${BATCH_JOB_MAX_RUNTIME} ttlSecondsAfterFinished: ${BATCH_JOB_TTL_SECONDS_AFTER_FINISHED} + completions: ${BATCH_JOB_NUM_WORKERS} + parallelism: ${BATCH_JOB_NUM_WORKERS} + completionMode: "Indexed" template: metadata: labels: @@ -3679,20 +3501,23 @@ data: created_by: ${CREATED_BY} owner: ${OWNER} env: circleci - managed-by: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + managed-by: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/env: circleci - tags.datadoghq.com/version: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - llm_engine_job_id: ${JOB_ID} + tags.datadoghq.com/version: ${GIT_TAG} + launch_job_id: ${JOB_ID} + tags.datadoghq.com/request_id: ${REQUEST_ID} tags.datadoghq.com/service: ${JOB_ID} + tags.datadoghq.com/user_id: ${OWNER} + tags.datadoghq.com/team: ${TEAM} sidecar.istio.io/inject: "false" version: v1 annotations: - ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "llm_engine_job_id:${JOB_ID}"]}]' + ad.datadoghq.com/main.logs: '[{"source": "python", "service": "${RESOURCE_NAME}", "tags": ["env:circleci", "launch_job_id:${JOB_ID}"]}]' + cluster-autoscaler.kubernetes.io/safe-to-evict: "false" spec: restartPolicy: Never nodeSelector: - node-lifecycle: normal k8s.amazonaws.com/accelerator: ${GPU_TYPE} tolerations: - key: "nvidia.com/gpu" @@ -3714,41 +3539,56 @@ data: env: - name: DD_SERVICE value: ${RESOURCE_NAME} - - name: DATADOG_TRACE_ENABLED + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" + - name: DD_TRACE_ENABLED value: "true" + - name: DD_REMOTE_CONFIGURATION_ENABLED + value: "false" - name: DD_ENV value: circleci - - name: DD_VERSION - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: DD_AGENT_HOST valueFrom: fieldRef: fieldPath: status.hostIP - - name: GIT_TAG - value: 54f8f73bfb1cce62a2b42326ccf9f49b5b145126 - name: SERVICE_IDENTIFIER + - name: GATEWAY_URL + value: http://model-engine.default:80 - name: AWS_PROFILE value: default + - name: AWS_CONFIG_FILE + value: /opt/.aws/config - name: ECR_READ_AWS_PROFILE value: default + - name: DB_SECRET_AWS_PROFILE + value: default + - name: S3_WRITE_AWS_PROFILE + value: default - name: ML_INFRA_DATABASE_URL valueFrom: secretKeyRef: key: database_url - name: ml-infra-pg + name: model-engine-postgres-credentials - name: DEPLOY_SERVICE_CONFIG_PATH - value: /workspace/llm_engine/service_configs/service_config.yaml + value: /workspace/model-engine/service_configs/service_config.yaml - name: ML_INFRA_SERVICES_CONFIG_PATH - value: /workspace/ml_infra_core/llm_engine.core/llm_engine.core/configs/config.yaml + value: /workspace/model-engine/model_engine_server/core/configs/config.yaml - name: CELERY_ELASTICACHE_ENABLED value: "true" - - name: LLM_ENGINE_SERVICE_TEMPLATE_FOLDER - value: /workspace/llm_engine/llm_engine/infra/gateways/resources/templates - imagePullPolicy: Always + - name: LAUNCH_SERVICE_TEMPLATE_FOLDER + value: /workspace/model-engine/model_engine_server/infra/gateways/resources/templates + - name: CIRCLECI + value: "true" + - name: DD_VERSION + value: ${GIT_TAG} + - name: GIT_TAG + value: ${GIT_TAG} + imagePullPolicy: IfNotPresent command: ${COMMAND} resources: # If job pods get evicted, then we can make "Guaranteed QoS" by setting requests = limits. requests: + nvidia.com/gpu: ${GPUS} cpu: ${CPUS} memory: ${MEMORY} ${STORAGE_DICT} @@ -3759,7 +3599,7 @@ data: ${STORAGE_DICT} volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} @@ -3767,11 +3607,14 @@ data: name: dshm initContainers: - name: input-downloader - image: 000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:54f8f73bfb1cce62a2b42326ccf9f49b5b145126 + image: model-engine:${GIT_TAG} + env: + - name: AWS_CONFIG_FILE + value: "/opt/.aws/config" command: - python - -m - - server.llm_engine_server.entrypoints.start_docker_image_batch_job_init_container + - model_engine_server.entrypoints.start_docker_image_batch_job_init_container - ${INPUT_LOCATION} - --remote-file - ${S3_FILE} @@ -3788,7 +3631,7 @@ data: memory: 1Gi volumeMounts: - name: config-volume - mountPath: /root/.aws/config + mountPath: /opt/.aws/config subPath: config - name: workdir mountPath: ${MOUNT_PATH} @@ -3800,8 +3643,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -3815,8 +3658,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -3837,8 +3680,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -3852,8 +3695,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -3878,8 +3721,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -3893,8 +3736,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -3919,8 +3762,8 @@ data: namespace: ${NAMESPACE} labels: team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} spec: selector: @@ -3934,8 +3777,8 @@ data: labels: app: ${RESOURCE_NAME} team: infra - product: llm-engine - use_scale_llm_engine_endpoint_network_policy: "true" + product: model-engine + use_scale_launch_endpoint_network_policy: "true" tags.datadoghq.com/service: ${RESOURCE_NAME} version: v1 sidecar.istio.io/inject: "false" @@ -3952,3 +3795,94 @@ data: name: busybox command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] terminationGracePeriodSeconds: 0 + image-cache-h100.yaml: |- + apiVersion: apps/v1 + kind: DaemonSet + metadata: + name: ${RESOURCE_NAME} + namespace: ${NAMESPACE} + labels: + team: infra + product: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/service: ${RESOURCE_NAME} + spec: + selector: + matchLabels: + app: ${RESOURCE_NAME} + version: v1 + updateStrategy: + type: RollingUpdate + template: + metadata: + labels: + app: ${RESOURCE_NAME} + team: infra + product: model-engine + use_scale_launch_endpoint_network_policy: "true" + tags.datadoghq.com/service: ${RESOURCE_NAME} + version: v1 + sidecar.istio.io/inject: "false" + spec: + nodeSelector: + k8s.amazonaws.com/accelerator: nvidia-hopper-h100 + tolerations: + - effect: NoSchedule + key: nvidia.com/gpu + operator: Exists + containers: + - image: public.ecr.aws/docker/library/busybox:latest + imagePullPolicy: IfNotPresent + name: busybox + command: ["/bin/sh", "-ec", "while : ; do sleep 30 ; done"] + terminationGracePeriodSeconds: 0 + cron-trigger.yaml: |- + apiVersion: batch/v1 + kind: CronJob + metadata: + name: ${NAME} + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + launch_trigger_id: ${TRIGGER_ID} + tags.datadoghq.com/service: ${TRIGGER_ID} + spec: + schedule: "${CRON_SCHEDULE}" + successfulJobsHistoryLimit: 0 + failedJobsHistoryLimit: 0 + jobTemplate: + spec: + backoffLimit: 0 + activeDeadlineSeconds: ${BATCH_CURL_JOB_ACTIVE_DEADLINE_SECONDS} + template: + metadata: + labels: + user_id: ${OWNER} + team: ${TEAM} + product: ${PRODUCT} + created_by: ${CREATED_BY} + owner: ${OWNER} + launch_trigger_id: ${TRIGGER_ID} + tags.datadoghq.com/service: ${TRIGGER_ID} + spec: + containers: + - name: ${NAME} + image: curlimages/curl:7.72.0 + imagePullPolicy: IfNotPresent + command: + - curl + - -X + - 'POST' + - '${HOST}/v1/docker-image-batch-jobs' + - -H + - 'accept: application/json' + - -H + - 'Content-Type: application/json' + - -d + - '{ "docker_image_batch_job_bundle_id": "${DOCKER_IMAGE_BATCH_JOB_BUNDLE_ID}", "job_config": ${JOB_CONFIG}, "labels": ${JOB_METADATA} }' + - -u + - '${OWNER}:' + restartPolicy: Never diff --git a/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py new file mode 100644 index 00000000..a5020740 --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/s3_file_storage_gateway.py @@ -0,0 +1,79 @@ +import os +from typing import List, Optional + +from model_engine_server.core.config import infra_config +from model_engine_server.domain.gateways.file_storage_gateway import ( + FileMetadata, + FileStorageGateway, +) +from model_engine_server.infra.gateways import S3FilesystemGateway + + +def get_s3_key(owner: str, file_id: str): + return os.path.join(owner, file_id) + + +def get_s3_url(owner: str, file_id: str): + return f"s3://{infra_config().s3_bucket}/{get_s3_key(owner, file_id)}" + + +class S3FileStorageGateway(FileStorageGateway): + """ + Concrete implementation of a file storage gateway backed by S3. + """ + + def __init__(self): + self.filesystem_gateway = S3FilesystemGateway() + + async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: + return self.filesystem_gateway.generate_signed_url(get_s3_url(owner, file_id)) + + async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: + try: + obj = self.filesystem_gateway.get_s3_client({}).head_object( + Bucket=infra_config().s3_bucket, + Key=get_s3_key(owner, file_id), + ) + return FileMetadata( + id=file_id, + filename=file_id, + size=obj.get("ContentLength"), + owner=owner, + updated_at=obj.get("LastModified"), + ) + except: # noqa: E722 + return None + + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + try: + with self.filesystem_gateway.open( + get_s3_url(owner, file_id), aws_profile=infra_config().profile_ml_worker + ) as f: + return f.read() + except: # noqa: E722 + return None + + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + with self.filesystem_gateway.open( + get_s3_url(owner, filename), mode="w", aws_profile=infra_config().profile_ml_worker + ) as f: + f.write(content.decode("utf-8")) + return filename + + async def delete_file(self, owner: str, file_id: str) -> bool: + try: + self.filesystem_gateway.get_s3_client({}).delete_object( + Bucket=infra_config().s3_bucket, + Key=get_s3_key(owner, file_id), + ) + return True + except: # noqa: E722 + return False + + async def list_files(self, owner: str) -> List[FileMetadata]: + objects = self.filesystem_gateway.get_s3_client({}).list_objects_v2( + Bucket=infra_config().s3_bucket, + Prefix=owner, + ) + files = [await self.get_file(owner, obj["Name"]) for obj in objects] + return [f for f in files if f is not None] diff --git a/server/llm_engine_server/infra/gateways/s3_filesystem_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py similarity index 77% rename from server/llm_engine_server/infra/gateways/s3_filesystem_gateway.py rename to model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py index 4dab06ba..b0bf9e84 100644 --- a/server/llm_engine_server/infra/gateways/s3_filesystem_gateway.py +++ b/model-engine/model_engine_server/infra/gateways/s3_filesystem_gateway.py @@ -4,16 +4,15 @@ import boto3 import smart_open - -from . import FilesystemGateway +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway class S3FilesystemGateway(FilesystemGateway): """ - Concrete implemention for interacting with a filesystem backed by S3. + Concrete implementation for interacting with a filesystem backed by S3. """ - def _get_s3_client(self, kwargs): + def get_s3_client(self, kwargs): profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) session = boto3.Session(profile_name=profile_name) client = session.client("s3") @@ -21,12 +20,12 @@ def _get_s3_client(self, kwargs): def open(self, uri: str, mode: str = "rt", **kwargs) -> IO: # This follows the 5.1.0 smart_open API - client = self._get_s3_client(kwargs) + client = self.get_s3_client(kwargs) transport_params = {"client": client} return smart_open.open(uri, mode, transport_params=transport_params) def generate_signed_url(self, uri: str, expiration: int = 3600, **kwargs) -> str: - client = self._get_s3_client(kwargs) + client = self.get_s3_client(kwargs) match = re.search("^s3://([^/]+)/(.*?)$", uri) assert match diff --git a/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py new file mode 100644 index 00000000..b48d1eef --- /dev/null +++ b/model-engine/model_engine_server/infra/gateways/s3_llm_artifact_gateway.py @@ -0,0 +1,85 @@ +import json +import os +from typing import Any, Dict, List + +import boto3 +from model_engine_server.common.config import get_model_cache_directory_name, hmi_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.utils.url import parse_attachment_url +from model_engine_server.domain.gateways import LLMArtifactGateway + +logger = make_logger(logger_name()) + + +class S3LLMArtifactGateway(LLMArtifactGateway): + """ + Concrete implemention for interacting with a filesystem backed by S3. + """ + + def _get_s3_resource(self, kwargs): + profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) + session = boto3.Session(profile_name=profile_name) + resource = session.resource("s3") + return resource + + def list_files(self, path: str, **kwargs) -> List[str]: + s3 = self._get_s3_resource(kwargs) + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = parsed_remote.key + + s3_bucket = s3.Bucket(bucket) + files = [obj.key for obj in s3_bucket.objects.filter(Prefix=key)] + return files + + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + s3 = self._get_s3_resource(kwargs) + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = parsed_remote.key + + s3_bucket = s3.Bucket(bucket) + downloaded_files: List[str] = [] + for obj in s3_bucket.objects.filter(Prefix=key): + file_path_suffix = obj.key.replace(key, "").lstrip("/") + local_path = os.path.join(target_path, file_path_suffix).rstrip("/") + + if not overwrite and os.path.exists(local_path): + downloaded_files.append(local_path) + continue + + local_dir = "/".join(local_path.split("/")[:-1]) + if not os.path.exists(local_dir): + os.makedirs(local_dir) + + logger.info(f"Downloading {obj.key} to {local_path}") + s3_bucket.download_file(obj.key, local_path) + downloaded_files.append(local_path) + return downloaded_files + + def get_model_weights_urls(self, owner: str, model_name: str, **kwargs) -> List[str]: + s3 = self._get_s3_resource(kwargs) + parsed_remote = parse_attachment_url( + hmi_config.hf_user_fine_tuned_weights_prefix, clean_key=False + ) + bucket = parsed_remote.bucket + fine_tuned_weights_prefix = parsed_remote.key + + s3_bucket = s3.Bucket(bucket) + model_files: List[str] = [] + model_cache_name = get_model_cache_directory_name(model_name) + prefix = f"{fine_tuned_weights_prefix}/{owner}/{model_cache_name}" + for obj in s3_bucket.objects.filter(Prefix=prefix): + model_files.append(f"s3://{bucket}/{obj.key}") + return model_files + + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + s3 = self._get_s3_resource(kwargs) + parsed_remote = parse_attachment_url(path, clean_key=False) + bucket = parsed_remote.bucket + key = os.path.join(parsed_remote.key, "config.json") + s3_bucket = s3.Bucket(bucket) + filepath = os.path.join("/tmp", key).replace("/", "_") + s3_bucket.download_file(key, filepath) + with open(filepath, "r") as f: + return json.load(f) diff --git a/server/llm_engine_server/infra/infra_utils.py b/model-engine/model_engine_server/infra/infra_utils.py similarity index 96% rename from server/llm_engine_server/infra/infra_utils.py rename to model-engine/model_engine_server/infra/infra_utils.py index db8e7182..b38083c9 100644 --- a/server/llm_engine_server/infra/infra_utils.py +++ b/model-engine/model_engine_server/infra/infra_utils.py @@ -2,7 +2,7 @@ from logging import LoggerAdapter from typing import Callable, Sequence -from llm_engine_server.common.env_vars import LOCAL +from model_engine_server.common.env_vars import LOCAL __all__: Sequence[str] = "make_exception_log" diff --git a/server/llm_engine_server/infra/repositories/__init__.py b/model-engine/model_engine_server/infra/repositories/__init__.py similarity index 57% rename from server/llm_engine_server/infra/repositories/__init__.py rename to model-engine/model_engine_server/infra/repositories/__init__.py index 061baa94..f14cf69f 100644 --- a/server/llm_engine_server/infra/repositories/__init__.py +++ b/model-engine/model_engine_server/infra/repositories/__init__.py @@ -1,31 +1,45 @@ from typing import Sequence +from .abs_file_llm_fine_tune_events_repository import ABSFileLLMFineTuneEventsRepository +from .abs_file_llm_fine_tune_repository import ABSFileLLMFineTuneRepository +from .acr_docker_repository import ACRDockerRepository from .batch_job_record_repository import BatchJobRecordRepository from .db_batch_job_record_repository import DbBatchJobRecordRepository from .db_docker_image_batch_job_bundle_repository import DbDockerImageBatchJobBundleRepository from .db_model_bundle_repository import DbModelBundleRepository from .db_model_endpoint_record_repository import DbModelEndpointRecordRepository +from .db_trigger_repository import DbTriggerRepository from .ecr_docker_repository import ECRDockerRepository +from .fake_docker_repository import FakeDockerRepository from .feature_flag_repository import FeatureFlagRepository -from .llm_fine_tuning_job_repository import LLMFineTuningJobRepository +from .live_tokenizer_repository import LiveTokenizerRepository +from .llm_fine_tune_repository import LLMFineTuneRepository from .model_endpoint_cache_repository import ModelEndpointCacheRepository from .model_endpoint_record_repository import ModelEndpointRecordRepository from .redis_feature_flag_repository import RedisFeatureFlagRepository from .redis_model_endpoint_cache_repository import RedisModelEndpointCacheRepository -from .s3_file_llm_fine_tuning_job_repository import S3FileLLMFineTuningJobRepository +from .s3_file_llm_fine_tune_events_repository import S3FileLLMFineTuneEventsRepository +from .s3_file_llm_fine_tune_repository import S3FileLLMFineTuneRepository __all__: Sequence[str] = [ + "ABSFileLLMFineTuneEventsRepository", + "ABSFileLLMFineTuneRepository", + "ACRDockerRepository", "BatchJobRecordRepository", "DbBatchJobRecordRepository", "DbDockerImageBatchJobBundleRepository", "DbModelBundleRepository", "DbModelEndpointRecordRepository", + "DbTriggerRepository", "ECRDockerRepository", + "FakeDockerRepository", "FeatureFlagRepository", - "LLMFineTuningJobRepository", + "LiveTokenizerRepository", + "LLMFineTuneRepository", "ModelEndpointRecordRepository", "ModelEndpointCacheRepository", "RedisFeatureFlagRepository", "RedisModelEndpointCacheRepository", - "S3FileLLMFineTuningJobRepository", + "S3FileLLMFineTuneRepository", + "S3FileLLMFineTuneEventsRepository", ] diff --git a/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py new file mode 100644 index 00000000..8a221c9f --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py @@ -0,0 +1,83 @@ +import json +import os +from json.decoder import JSONDecodeError +from typing import IO, List + +import smart_open +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent +from model_engine_server.domain.exceptions import ObjectNotFoundException +from model_engine_server.domain.repositories.llm_fine_tune_events_repository import ( + LLMFineTuneEventsRepository, +) + +# Echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py +ABS_HF_USER_FINE_TUNED_WEIGHTS_PREFIX = ( + f"azure://{os.getenv('ABS_CONTAINER_NAME')}/hosted-model-inference/fine_tuned_weights" +) + + +class ABSFileLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): + def __init__(self): + pass + + def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + # echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py + def _get_model_cache_directory_name(self, model_name: str): + """How huggingface maps model names to directory names in their cache for model files. + We adopt this when storing model cache files in ABS. + + Args: + model_name (str): Name of the huggingface model + """ + name = "models--" + model_name.replace("/", "--") + return name + + def _get_file_location(self, user_id: str, model_endpoint_name: str): + model_cache_name = self._get_model_cache_directory_name(model_endpoint_name) + abs_file_location = ( + f"{ABS_HF_USER_FINE_TUNED_WEIGHTS_PREFIX}/{user_id}/{model_cache_name}.jsonl" + ) + return abs_file_location + + async def get_fine_tune_events( + self, user_id: str, model_endpoint_name: str + ) -> List[LLMFineTuneEvent]: + abs_file_location = self._get_file_location( + user_id=user_id, model_endpoint_name=model_endpoint_name + ) + try: + with self._open(abs_file_location, "r") as f: + lines = f.readlines() + final_events = [] + for line in lines: + try: + event_dict = json.loads(line) + event = LLMFineTuneEvent( + timestamp=event_dict["timestamp"], + message=str(event_dict["message"]), + level=event_dict.get("level", "info"), + ) + except JSONDecodeError: + event = LLMFineTuneEvent( + message=line, + level="info", + ) + final_events.append(event) + return final_events + except Exception as exc: # TODO better exception + raise ObjectNotFoundException from exc + + async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None: + abs_file_location = self._get_file_location( + user_id=user_id, model_endpoint_name=model_endpoint_name + ) + self._open(abs_file_location, "w") diff --git a/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py new file mode 100644 index 00000000..fc8860f2 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py @@ -0,0 +1,53 @@ +import json +import os +from typing import IO, Dict, Optional + +import smart_open +from azure.identity import DefaultAzureCredential +from azure.storage.blob import BlobServiceClient +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository + + +class ABSFileLLMFineTuneRepository(LLMFineTuneRepository): + def __init__(self, file_path: str): + self.file_path = file_path + + def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + client = BlobServiceClient( + f"https://{os.getenv('ABS_ACCOUNT_NAME')}.blob.core.windows.net", + DefaultAzureCredential(), + ) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + @staticmethod + def _get_key(model_name, fine_tuning_method): + return f"{model_name}-{fine_tuning_method}" # possible for collisions but we control these names + + async def get_job_template_for_model( + self, model_name: str, fine_tuning_method: str + ) -> Optional[LLMFineTuneTemplate]: + with self._open(self.file_path, "r") as f: + data = json.load(f) + key = self._get_key(model_name, fine_tuning_method) + job_template_dict = data.get(key, None) + if job_template_dict is None: + return None + return LLMFineTuneTemplate.parse_obj(job_template_dict) + + async def write_job_template_for_model( + self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate + ): + # Use locally in script + with self._open(self.file_path, "r") as f: + data: Dict = json.load(f) + key = self._get_key(model_name, fine_tuning_method) + data[key] = dict(job_template) + with self._open(self.file_path, "w") as f: + json.dump(data, f) + + async def initialize_data(self): + # Use locally in script + with self._open(self.file_path, "w") as f: + json.dump({}, f) diff --git a/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py new file mode 100644 index 00000000..7f9137fe --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/acr_docker_repository.py @@ -0,0 +1,47 @@ +from typing import Optional + +from azure.containerregistry import ContainerRegistryClient +from azure.core.exceptions import ResourceNotFoundError +from azure.identity import DefaultAzureCredential +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException +from model_engine_server.domain.repositories import DockerRepository + +logger = make_logger(logger_name()) + + +class ACRDockerRepository(DockerRepository): + def image_exists( + self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None + ) -> bool: + endpoint = f"https://{infra_config().docker_repo_prefix}" + credential = DefaultAzureCredential() + client = ContainerRegistryClient(endpoint, credential) + + try: + client.get_manifest_properties(repository_name, image_tag) + except ResourceNotFoundError: + return False + return True + + def get_image_url(self, image_tag: str, repository_name: str) -> str: + return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + + def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + raise NotImplementedError("ACR image build not supported yet") + + def get_latest_image_tag(self, repository_name: str) -> str: + endpoint = f"https://{infra_config().docker_repo_prefix}" + credential = DefaultAzureCredential() + client = ContainerRegistryClient(endpoint, credential) + + try: + image = client.list_manifest_properties( + repository_name, order_by="time_desc", results_per_page=1 + ).next() + # Azure automatically deletes empty ACR repositories, so repos will always have at least one image + return image.tags[0] + except ResourceNotFoundError: + raise DockerRepositoryNotFoundException diff --git a/server/llm_engine_server/infra/repositories/batch_job_record_repository.py b/model-engine/model_engine_server/infra/repositories/batch_job_record_repository.py similarity index 97% rename from server/llm_engine_server/infra/repositories/batch_job_record_repository.py rename to model-engine/model_engine_server/infra/repositories/batch_job_record_repository.py index 6b33ec29..982aaa5f 100644 --- a/server/llm_engine_server/infra/repositories/batch_job_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/batch_job_record_repository.py @@ -2,7 +2,7 @@ from datetime import datetime from typing import List, Optional -from llm_engine_server.domain.entities import BatchJobRecord, BatchJobStatus +from model_engine_server.domain.entities import BatchJobRecord, BatchJobStatus class BatchJobRecordRepository(ABC): diff --git a/server/llm_engine_server/infra/repositories/db_batch_job_record_repository.py b/model-engine/model_engine_server/infra/repositories/db_batch_job_record_repository.py similarity index 90% rename from server/llm_engine_server/infra/repositories/db_batch_job_record_repository.py rename to model-engine/model_engine_server/infra/repositories/db_batch_job_record_repository.py index 3ded1566..6aa9feb0 100644 --- a/server/llm_engine_server/infra/repositories/db_batch_job_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_batch_job_record_repository.py @@ -1,24 +1,22 @@ from datetime import datetime from typing import Any, Dict, List, Optional -from llm_engine_server.common import dict_not_none -from llm_engine_server.db.models import BatchJob as OrmBatchJob -from llm_engine_server.domain.entities import BatchJobRecord, BatchJobStatus -from llm_engine_server.infra.repositories.batch_job_record_repository import ( +from model_engine_server.common import dict_not_none +from model_engine_server.db.models import BatchJob as OrmBatchJob +from model_engine_server.domain.entities import BatchJobRecord, BatchJobStatus +from model_engine_server.infra.repositories.batch_job_record_repository import ( BatchJobRecordRepository, ) -from llm_engine_server.infra.repositories.db_model_bundle_repository import ( +from model_engine_server.infra.repositories.db_model_bundle_repository import ( translate_model_bundle_orm_to_model_bundle, ) -from llm_engine_server.infra.repositories.db_repository_mixin import ( +from model_engine_server.infra.repositories.db_repository_mixin import ( DbRepositoryMixin, raise_if_read_only, ) -def translate_batch_job_orm_to_batch_job_record( - batch_job_orm: OrmBatchJob, -) -> BatchJobRecord: +def translate_batch_job_orm_to_batch_job_record(batch_job_orm: OrmBatchJob) -> BatchJobRecord: return BatchJobRecord( id=batch_job_orm.id, created_at=batch_job_orm.created_at, diff --git a/server/llm_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py b/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py similarity index 89% rename from server/llm_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py rename to model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py index d1774419..b97f57f4 100644 --- a/server/llm_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_docker_image_batch_job_bundle_repository.py @@ -1,21 +1,21 @@ from typing import Dict, List, Optional, Sequence -from llm_engine_server.common import dict_not_none -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle -from llm_engine_server.domain.entities import GpuType -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.common.pydantic_types import ValidationError +from model_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle +from model_engine_server.domain.entities import GpuType +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from llm_engine_server.domain.exceptions import CorruptRecordInfraStateException -from llm_engine_server.domain.repositories.docker_image_batch_job_bundle_repository import ( +from model_engine_server.domain.exceptions import CorruptRecordInfraStateException +from model_engine_server.domain.repositories.docker_image_batch_job_bundle_repository import ( DockerImageBatchJobBundleRepository, ) -from llm_engine_server.infra.repositories.db_repository_mixin import ( +from model_engine_server.infra.repositories.db_repository_mixin import ( DbRepositoryMixin, raise_if_read_only, ) -from pydantic.error_wrappers import ValidationError class DbDockerImageBatchJobBundleRepository(DockerImageBatchJobBundleRepository, DbRepositoryMixin): diff --git a/server/llm_engine_server/infra/repositories/db_model_bundle_repository.py b/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py similarity index 87% rename from server/llm_engine_server/infra/repositories/db_model_bundle_repository.py rename to model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py index 73700b9a..b84a598e 100644 --- a/server/llm_engine_server/infra/repositories/db_model_bundle_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_model_bundle_repository.py @@ -1,15 +1,15 @@ from typing import Any, Dict, List, Optional, Sequence -from llm_engine_server.common import dict_not_none -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.db.models import Bundle as OrmModelBundle -from llm_engine_server.domain.entities import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.db.models import Bundle as OrmModelBundle +from model_engine_server.domain.entities import ( ModelBundle, ModelBundleFlavors, ModelBundlePackagingType, ) -from llm_engine_server.domain.repositories import ModelBundleRepository -from llm_engine_server.infra.repositories.db_repository_mixin import ( +from model_engine_server.domain.repositories import ModelBundleRepository +from model_engine_server.infra.repositories.db_repository_mixin import ( DbRepositoryMixin, raise_if_read_only, ) @@ -57,7 +57,7 @@ async def create_model_bundle( ) async with self.session() as session: await OrmModelBundle.create(session, model_bundle_record) - model_bundle_record = await OrmModelBundle.select_by_id( # type: ignore + model_bundle_record = await OrmModelBundle.select_by_id( session=session, bundle_id=model_bundle_record.id ) return translate_model_bundle_orm_to_model_bundle(model_bundle_record) @@ -122,14 +122,16 @@ def translate_model_bundle_orm_to_model_bundle( flavor=model_bundle_orm.flavor, requirements=model_bundle_orm.artifact_requirements, location=model_bundle_orm.artifact_location, - framework=None - if model_bundle_orm.artifact_framework_type is None - else dict_not_none( - framework_type=model_bundle_orm.artifact_framework_type, - pytorch_image_tag=model_bundle_orm.artifact_pytorch_image_tag, - tensorflow_version=model_bundle_orm.artifact_tensorflow_version, - image_repository=model_bundle_orm.artifact_image_repository, - image_tag=model_bundle_orm.artifact_image_tag, + framework=( + None + if model_bundle_orm.artifact_framework_type is None + else dict_not_none( + framework_type=model_bundle_orm.artifact_framework_type, + pytorch_image_tag=model_bundle_orm.artifact_pytorch_image_tag, + tensorflow_version=model_bundle_orm.artifact_tensorflow_version, + image_repository=model_bundle_orm.artifact_image_repository, + image_tag=model_bundle_orm.artifact_image_tag, + ) ), app_config=model_bundle_orm.artifact_app_config, load_predict_fn=model_bundle_orm.cloudpickle_artifact_load_predict_fn, @@ -144,6 +146,9 @@ def translate_model_bundle_orm_to_model_bundle( env=model_bundle_orm.runnable_image_env, protocol=model_bundle_orm.runnable_image_protocol, readiness_initial_delay_seconds=model_bundle_orm.runnable_image_readiness_initial_delay_seconds, + extra_routes=model_bundle_orm.runnable_image_extra_routes, + worker_command=model_bundle_orm.runnable_image_worker_command, + worker_env=model_bundle_orm.runnable_image_worker_env, streaming_command=model_bundle_orm.streaming_enhanced_runnable_image_streaming_command, streaming_predict_route=model_bundle_orm.streaming_enhanced_runnable_image_streaming_predict_route, triton_model_repository=model_bundle_orm.triton_enhanced_runnable_image_model_repository, @@ -161,7 +166,7 @@ def translate_model_bundle_orm_to_model_bundle( packaging_type=model_bundle_orm.packaging_type, app_config=model_bundle_orm.app_config, ) - return ModelBundle.parse_obj(kwargs) + return ModelBundle.model_validate(kwargs) def translate_kwargs_to_model_bundle_orm( @@ -212,6 +217,9 @@ def translate_kwargs_to_model_bundle_orm( runnable_image_readiness_initial_delay_seconds=flavor_dict.get( "readiness_initial_delay_seconds" ), + runnable_image_extra_routes=flavor_dict.get("extra_routes"), + runnable_image_worker_command=flavor_dict.get("worker_command"), + runnable_image_worker_env=flavor_dict.get("worker_env"), streaming_enhanced_runnable_image_streaming_command=flavor_dict.get("streaming_command"), streaming_enhanced_runnable_image_streaming_predict_route=flavor_dict.get( "streaming_predict_route" diff --git a/server/llm_engine_server/infra/repositories/db_model_endpoint_record_repository.py b/model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py similarity index 94% rename from server/llm_engine_server/infra/repositories/db_model_endpoint_record_repository.py rename to model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py index b7803b3a..bfd8cab0 100644 --- a/server/llm_engine_server/infra/repositories/db_model_endpoint_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/db_model_endpoint_record_repository.py @@ -3,27 +3,27 @@ from typing import Any, Callable, Dict, List, Optional from cachetools import TTLCache -from llm_engine_server.common import dict_not_none -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.db.endpoint_row_lock import AdvisoryLockContextManager, get_lock_key -from llm_engine_server.db.models import Endpoint as OrmModelEndpoint -from llm_engine_server.domain.entities import ModelEndpointRecord -from llm_engine_server.domain.gateways import MonitoringMetricsGateway -from llm_engine_server.infra.repositories.db_model_bundle_repository import ( +from model_engine_server.common import dict_not_none +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.db.endpoint_row_lock import AdvisoryLockContextManager, get_lock_key +from model_engine_server.db.models import Endpoint as OrmModelEndpoint +from model_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.domain.gateways import MonitoringMetricsGateway +from model_engine_server.infra.repositories.db_model_bundle_repository import ( translate_model_bundle_orm_to_model_bundle, ) -from llm_engine_server.infra.repositories.db_repository_mixin import ( +from model_engine_server.infra.repositories.db_repository_mixin import ( DbRepositoryMixin, raise_if_read_only, ) -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) from sqlalchemy import or_, text from sqlalchemy.ext.asyncio import AsyncSession -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) CACHE_SIZE = 512 CACHE_TTL_SECONDS = 15.0 # Kubernetes caching is 15 seconds as well @@ -202,7 +202,7 @@ async def list_llm_model_endpoint_records( if owner: ownership_filters.append(OrmModelEndpoint.owner == owner) filters.append( - or_(*ownership_filters, OrmModelEndpoint.public_inference == True) # noqa + or_(*ownership_filters, OrmModelEndpoint.public_inference == True) # noqa: E712 ) async with self.session() as session: diff --git a/server/llm_engine_server/infra/repositories/db_repository_mixin.py b/model-engine/model_engine_server/infra/repositories/db_repository_mixin.py similarity index 85% rename from server/llm_engine_server/infra/repositories/db_repository_mixin.py rename to model-engine/model_engine_server/infra/repositories/db_repository_mixin.py index f1d26a81..e0e9f242 100644 --- a/server/llm_engine_server/infra/repositories/db_repository_mixin.py +++ b/model-engine/model_engine_server/infra/repositories/db_repository_mixin.py @@ -2,7 +2,7 @@ from functools import wraps from typing import Callable -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException +from model_engine_server.domain.exceptions import ReadOnlyDatabaseException from sqlalchemy.ext.asyncio import AsyncSession diff --git a/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py b/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py new file mode 100644 index 00000000..367942f9 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/db_trigger_repository.py @@ -0,0 +1,134 @@ +from typing import Any, Dict, Optional, Sequence + +from model_engine_server.common import dict_not_none +from model_engine_server.common.pydantic_types import ValidationError +from model_engine_server.db.models import Trigger as OrmTrigger +from model_engine_server.domain.entities.trigger_entity import Trigger +from model_engine_server.domain.exceptions import ( + CorruptRecordInfraStateException, + TriggerNameAlreadyExistsException, +) +from model_engine_server.domain.repositories.trigger_repository import TriggerRepository +from model_engine_server.infra.repositories.db_repository_mixin import ( + DbRepositoryMixin, + raise_if_read_only, +) +from sqlalchemy.exc import IntegrityError + + +class DbTriggerRepository(TriggerRepository, DbRepositoryMixin): + @raise_if_read_only + async def create_trigger( + self, + *, + name: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Optional[Dict[str, str]], + ) -> Trigger: + trigger_record = translate_kwargs_to_trigger_orm( + name=name, + created_by=created_by, + owner=owner, + cron_schedule=cron_schedule, + docker_image_batch_job_bundle_id=docker_image_batch_job_bundle_id, + default_job_config=default_job_config, + default_job_metadata=default_job_metadata, + ) + try: + async with self.session() as session: + await OrmTrigger.create(session, trigger_record) + trigger_record = await OrmTrigger.select_by_id( + session=session, trigger_id=trigger_record.id + ) + except IntegrityError: + raise TriggerNameAlreadyExistsException( + f"Trigger with name {name} already exists for {owner}" + ) + return translate_trigger_orm_to_entity(trigger_record) + + async def list_triggers(self, owner: str) -> Sequence[Trigger]: + async with self.session() as session: + trigger_records = await OrmTrigger.select_all_by_owner(session=session, owner=owner) + triggers = [translate_trigger_orm_to_entity(tr) for tr in trigger_records] + return triggers + + async def get_trigger(self, trigger_id: str) -> Optional[Trigger]: + async with self.session() as session: + trigger_record = await OrmTrigger.select_by_id(session=session, trigger_id=trigger_id) + if not trigger_record: + return None + + return translate_trigger_orm_to_entity(trigger_record) + + @raise_if_read_only + async def update_trigger( + self, + trigger_id: str, + cron_schedule: str, + ) -> bool: + async with self.session() as session: + trigger = await OrmTrigger.select_by_id(session=session, trigger_id=trigger_id) + if trigger is None: + return False + + await OrmTrigger.update_by_id( + session=session, trigger_id=trigger_id, kwargs=dict(cron_schedule=cron_schedule) + ) + return True + + @raise_if_read_only + async def delete_trigger( + self, + trigger_id: str, + ) -> bool: + async with self.session() as session: + trigger = await OrmTrigger.select_by_id(session=session, trigger_id=trigger_id) + if trigger is None: + return False + + await OrmTrigger.delete_by_id(session=session, trigger_id=trigger_id) + return True + + +def translate_trigger_orm_to_entity( + trigger_orm: OrmTrigger, +) -> Trigger: + kwargs = dict_not_none( + id=trigger_orm.id, + name=trigger_orm.name, + owner=trigger_orm.owner, + created_at=trigger_orm.created_at, + created_by=trigger_orm.created_by, + cron_schedule=trigger_orm.cron_schedule, + docker_image_batch_job_bundle_id=trigger_orm.docker_image_batch_job_bundle_id, + default_job_config=trigger_orm.default_job_config, + default_job_metadata=trigger_orm.default_job_metadata, + ) + try: + return Trigger.parse_obj(kwargs) + except ValidationError as exc: + raise CorruptRecordInfraStateException() from exc + + +def translate_kwargs_to_trigger_orm( + name: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Optional[Dict[str, str]], +) -> OrmTrigger: + return OrmTrigger( + name=name, + owner=owner, + created_by=created_by, + cron_schedule=cron_schedule, + docker_image_batch_job_bundle_id=docker_image_batch_job_bundle_id, + default_job_config=default_job_config, + default_job_metadata=default_job_metadata, + ) diff --git a/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py new file mode 100644 index 00000000..d283c4c4 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/ecr_docker_repository.py @@ -0,0 +1,58 @@ +from typing import Optional + +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.docker.ecr import get_latest_image_tag +from model_engine_server.core.docker.ecr import image_exists as ecr_image_exists +from model_engine_server.core.docker.remote_build import build_remote_block +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.repositories import DockerRepository + +logger = make_logger(logger_name()) + + +class ECRDockerRepository(DockerRepository): + def image_exists( + self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None + ) -> bool: + return ecr_image_exists( + image_tag=image_tag, + repository_name=repository_name, + aws_profile=aws_profile, + ) + + def get_image_url(self, image_tag: str, repository_name: str) -> str: + return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + + def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + logger.info(f"build_image args {locals()}") + folders_to_include = ["model-engine"] + if image_params.requirements_folder: + folders_to_include.append(image_params.requirements_folder) + + dockerfile_root_folder = image_params.dockerfile.split("/")[0] + if dockerfile_root_folder not in folders_to_include: + folders_to_include.append(dockerfile_root_folder) + + build_args = { + "BASE_IMAGE": image_params.base_image, + } + + if image_params.substitution_args: + build_args.update(image_params.substitution_args) + + build_result = build_remote_block( + context=image_params.base_path, + dockerfile=image_params.dockerfile, + repotags=[f"{image_params.repo}:{image_params.image_tag}"], + folders_to_include=folders_to_include, + build_args=build_args, + cache_name=hmi_config.docker_image_layer_cache_repository, + ) + return BuildImageResponse( + status=build_result.status, logs=build_result.logs, job_name=build_result.job_name + ) + + def get_latest_image_tag(self, repository_name: str) -> str: + return get_latest_image_tag(repository_name=repository_name) diff --git a/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py new file mode 100644 index 00000000..2d12de6e --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/fake_docker_repository.py @@ -0,0 +1,24 @@ +from typing import Optional + +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.repositories import DockerRepository + +logger = make_logger(logger_name()) + + +class FakeDockerRepository(DockerRepository): + def image_exists( + self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None + ) -> bool: + return True + + def get_image_url(self, image_tag: str, repository_name: str) -> str: + return f"{infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" + + def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: + raise NotImplementedError("FakeDockerRepository build_image() not implemented") + + def get_latest_image_tag(self, repository_name: str) -> str: + raise NotImplementedError("FakeDockerRepository get_latest_image_tag() not implemented") diff --git a/server/llm_engine_server/infra/repositories/feature_flag_repository.py b/model-engine/model_engine_server/infra/repositories/feature_flag_repository.py similarity index 100% rename from server/llm_engine_server/infra/repositories/feature_flag_repository.py rename to model-engine/model_engine_server/infra/repositories/feature_flag_repository.py diff --git a/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py new file mode 100644 index 00000000..54e6436c --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/live_tokenizer_repository.py @@ -0,0 +1,189 @@ +import os +from functools import lru_cache +from typing import Dict, NamedTuple, Optional + +from huggingface_hub import list_repo_refs +from huggingface_hub.utils._errors import RepositoryNotFoundError +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.exceptions import ObjectNotFoundException +from model_engine_server.domain.gateways.llm_artifact_gateway import LLMArtifactGateway +from model_engine_server.domain.repositories.tokenizer_repository import TokenizerRepository +from transformers import AutoTokenizer + +logger = make_logger(logger_name()) + + +TOKENIZER_FILES_REQUIRED = [ + "tokenizer.json", + "tokenizer_config.json", + "special_tokens_map.json", +] +TOKENIZER_FILES_OPTIONAL = [ + "tokenizer.model", +] +TOKENIZER_TARGET_DIR = "/opt/.cache/model_engine_server/tokenizers" + + +class ModelInfo(NamedTuple): + hf_repo: str + s3_repo: Optional[str] + + +def get_default_supported_models_info() -> Dict[str, ModelInfo]: + return { + "mpt-7b": ModelInfo("mosaicml/mpt-7b", None), + "mpt-7b-instruct": ModelInfo("mosaicml/mpt-7b-instruct", None), + "flan-t5-xxl": ModelInfo("google/flan-t5-xxl", None), + "llama-7b": ModelInfo("decapoda-research/llama-7b-hf", None), + "llama-2-7b": ModelInfo("huggyllama/llama-7b", None), + "llama-2-7b-chat": ModelInfo("meta-llama/Llama-2-7b-chat-hf", None), + "llama-2-13b": ModelInfo("meta-llama/Llama-2-13b-hf", None), + "llama-2-13b-chat": ModelInfo("meta-llama/Llama-2-13b-chat-hf", None), + "llama-2-70b": ModelInfo("meta-llama/Llama-2-70b-hf", None), + "llama-2-70b-chat": ModelInfo("meta-llama/Llama-2-70b-chat-hf", None), + "llama-3-8b": ModelInfo("meta-llama/Meta-Llama-3-8B", None), + "llama-3-8b-instruct": ModelInfo("meta-llama/Meta-Llama-3-8B-Instruct", None), + "llama-3-8b-instruct-262k": ModelInfo("gradientai/Llama-3-8B-Instruct-262k", None), + "llama-3-70b": ModelInfo("meta-llama/Meta-Llama-3-70B", None), + "llama-3-70b-instruct": ModelInfo("meta-llama/Meta-Llama-3-70B-Instruct", None), + "llama-3-1-8b": ModelInfo("meta-llama/Meta-Llama-3.1-8B", None), + "llama-3-1-8b-instruct": ModelInfo("meta-llama/Meta-Llama-3.1-8B-Instruct", None), + "llama-3-1-70b": ModelInfo("meta-llama/Meta-Llama-3.1-70B", None), + "llama-3-1-70b-instruct": ModelInfo("meta-llama/Meta-Llama-3.1-70B-Instruct", None), + "llama-3-1-405b": ModelInfo("meta-llama/Meta-Llama-3.1-405B", None), + "llama-3-1-405b-instruct": ModelInfo("meta-llama/Meta-Llama-3.1-405B-Instruct", None), + "falcon-7b": ModelInfo("tiiuae/falcon-7b", None), + "falcon-7b-instruct": ModelInfo("tiiuae/falcon-7b-instruct", None), + "falcon-40b": ModelInfo("tiiuae/falcon-40b", None), + "falcon-40b-instruct": ModelInfo("tiiuae/falcon-40b-instruct", None), + "falcon-180b": ModelInfo("tiiuae/falcon-180B", None), + "falcon-180b-chat": ModelInfo("tiiuae/falcon-180B-chat", None), + "codellama-7b": ModelInfo("codellama/CodeLlama-7b-hf", None), + "codellama-7b-instruct": ModelInfo("codellama/CodeLlama-7b-Instruct-hf", None), + "codellama-13b": ModelInfo("codellama/CodeLlama-13b-hf", None), + "codellama-13b-instruct": ModelInfo("codellama/CodeLlama-13b-Instruct-hf", None), + "codellama-34b": ModelInfo("codellama/CodeLlama-34b-hf", None), + "codellama-34b-instruct": ModelInfo("codellama/CodeLlama-34b-Instruct-hf", None), + "codellama-70b": ModelInfo("codellama/CodeLlama-70b-hf", None), + "codellama-70b-instruct": ModelInfo("codellama/CodeLlama-70b-Instruct-hf", None), + "llm-jp-13b-instruct-full": ModelInfo("llm-jp/llm-jp-13b-instruct-full-jaster-v1.0", None), + "llm-jp-13b-instruct-full-dolly": ModelInfo( + "llm-jp/llm-jp-13b-instruct-full-dolly-oasst-v1.0", None + ), + "mistral-7b": ModelInfo("mistralai/Mistral-7B-v0.1", None), + "mistral-7b-instruct": ModelInfo("mistralai/Mistral-7B-Instruct-v0.1", None), + "mixtral-8x7b": ModelInfo("mistralai/Mixtral-8x7B-v0.1", None), + "mixtral-8x7b-instruct": ModelInfo("mistralai/Mixtral-8x7B-Instruct-v0.1", None), + "mixtral-8x22b": ModelInfo("mistralai/Mixtral-8x22B-v0.1", None), + "mixtral-8x22b-instruct": ModelInfo("mistralai/Mixtral-8x22B-Instruct-v0.1", None), + "mammoth-coder-llama-2-7b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-7B", None), + "mammoth-coder-llama-2-13b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-13B", None), + "mammoth-coder-llama-2-34b": ModelInfo("TIGER-Lab/MAmmoTH-Coder-34B", None), + "gpt-j-6b": ModelInfo("EleutherAI/gpt-j-6b", None), + "gpt-j-6b-zh-en": ModelInfo("EleutherAI/gpt-j-6b", None), + "gpt4all-j": ModelInfo("nomic-ai/gpt4all-j", None), + "dolly-v2-12b": ModelInfo("databricks/dolly-v2-12b", None), + "stablelm-tuned-7b": ModelInfo("StabilityAI/stablelm-tuned-alpha-7b", None), + "vicuna-13b": ModelInfo("eachadea/vicuna-13b-1.1", None), + "zephyr-7b-alpha": ModelInfo("HuggingFaceH4/zephyr-7b-alpha", None), + "zephyr-7b-beta": ModelInfo("HuggingFaceH4/zephyr-7b-beta", None), + "gemma-2-2b": ModelInfo("google/gemma-2-2b", None), + "gemma-2-2b-instruct": ModelInfo("google/gemma-2-2b-it", None), + "gemma-2-7b": ModelInfo("google/gemma-2-7b", None), + "gemma-2-7b-instruct": ModelInfo("google/gemma-2-7b-it", None), + "phi-3-mini-4k-instruct": ModelInfo("microsoft/phi-3-mini-4k-instruct", None), + "phi-3-mini-128k-instruct": ModelInfo("microsoft/phi-3-mini-128k-instruct", None), + "phi-3-small-8k-instruct": ModelInfo("microsoft/phi-3-small-8k-instruct", None), + "phi-3-small-128k-instruct": ModelInfo("microsoft/phi-3-small-128k-instruct", None), + "phi-3-medium-4-instruct": ModelInfo("microsoft/phi-3-medium-4k-instruct", None), + "phi-3-medium-128k-instruct": ModelInfo("microsoft/phi-3-medium-128k-instruct", None), + "deepseek-coder-v2": ModelInfo("deepseek-ai/DeepSeek-Coder-V2-Base", None), + "deepseek-coder-v2-instruct": ModelInfo("deepseek-ai/DeepSeek-Coder-V2-Instruct", None), + "deepseek-coder-v2-lite": ModelInfo("deepseek-ai/DeepSeek-Coder-V2-Lite-Base", None), + "deepseek-coder-v2-lite-instruct": ModelInfo( + "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", None + ), + "qwen2-72b-instruct": ModelInfo( + "Qwen/Qwen2-72B-Instruct", + None, + ), + } + + +def get_supported_models_info() -> Dict[str, ModelInfo]: + try: + from plugins.live_tokenizer_repository import ( + get_supported_models_info as get_custom_supported_models_info, + ) + + return get_custom_supported_models_info() + except ModuleNotFoundError: + return get_default_supported_models_info() + + +SUPPORTED_MODELS_INFO = get_supported_models_info() + + +def get_models_s3_uri(*args, **kwargs) -> str: + try: + from plugins.live_tokenizer_repository import get_models_s3_uri as get_custom_models_s3_uri + + return get_custom_models_s3_uri(*args, **kwargs) + except ModuleNotFoundError: + raise NotImplementedError + + +def get_models_local_dir_path(model_name: str) -> str: + """ + Get the local directory path for a given model. + """ + return f"{TOKENIZER_TARGET_DIR}/{model_name}" + + +class LiveTokenizerRepository(TokenizerRepository): + def __init__(self, llm_artifact_gateway: LLMArtifactGateway): + self.llm_artifact_gateway = llm_artifact_gateway + + def _load_tokenizer_from_s3(self, model_name: str, s3_prefix: Optional[str]) -> Optional[str]: + """ + Download tokenizer files from S3 to the local filesystem. + """ + if not s3_prefix: + return None + + model_tokenizer_dir = get_models_local_dir_path(model_name) + + for file in TOKENIZER_FILES_REQUIRED: + s3_path = get_models_s3_uri(s3_prefix, file) + target_path = os.path.join(model_tokenizer_dir, file) + self.llm_artifact_gateway.download_files(s3_path, target_path) + + for file in TOKENIZER_FILES_OPTIONAL: + s3_path = get_models_s3_uri(s3_prefix, file) + target_path = os.path.join(model_tokenizer_dir, file) + try: + self.llm_artifact_gateway.download_files(s3_path, target_path) + except Exception: + pass + + return model_tokenizer_dir + + @lru_cache(maxsize=32) + def load_tokenizer(self, model_name: str) -> AutoTokenizer: + model_info = SUPPORTED_MODELS_INFO[model_name] + + model_location = None + try: + if not model_info.hf_repo: + raise RepositoryNotFoundError("No HF repo specified for model.") + list_repo_refs(model_info.hf_repo) # check if model exists in Hugging Face Hub + model_location = model_info.hf_repo + # AutoTokenizer handles file downloads for HF repos + except RepositoryNotFoundError: + model_location = self._load_tokenizer_from_s3(model_name, model_info.s3_repo) + + if not model_location: + raise ObjectNotFoundException(f"Tokenizer not found for model {model_name}.") + + logger.info(f"Loading tokenizer for model {model_name} from {model_location}.") + return AutoTokenizer.from_pretrained(model_location) diff --git a/server/llm_engine_server/infra/repositories/llm_fine_tuning_job_repository.py b/model-engine/model_engine_server/infra/repositories/llm_fine_tune_repository.py similarity index 69% rename from server/llm_engine_server/infra/repositories/llm_fine_tuning_job_repository.py rename to model-engine/model_engine_server/infra/repositories/llm_fine_tune_repository.py index 4a7acc1e..b33d74de 100644 --- a/server/llm_engine_server/infra/repositories/llm_fine_tuning_job_repository.py +++ b/model-engine/model_engine_server/infra/repositories/llm_fine_tune_repository.py @@ -1,10 +1,10 @@ from abc import ABC, abstractmethod from typing import Optional -from llm_engine_server.domain.entities.llm_fine_tune_job_entity import LLMFineTuneJobTemplate +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate -class LLMFineTuningJobRepository(ABC): +class LLMFineTuneRepository(ABC): """ Basically a store of model name + fine tuning method -> docker image batch job bundle ids @@ -13,11 +13,11 @@ class LLMFineTuningJobRepository(ABC): @abstractmethod async def get_job_template_for_model( self, model_name: str, fine_tuning_method: str - ) -> Optional[LLMFineTuneJobTemplate]: + ) -> Optional[LLMFineTuneTemplate]: pass @abstractmethod async def write_job_template_for_model( - self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneJobTemplate + self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate ): pass diff --git a/server/llm_engine_server/infra/repositories/model_endpoint_cache_repository.py b/model-engine/model_engine_server/infra/repositories/model_endpoint_cache_repository.py similarity index 93% rename from server/llm_engine_server/infra/repositories/model_endpoint_cache_repository.py rename to model-engine/model_engine_server/infra/repositories/model_endpoint_cache_repository.py index 2d8a22a9..3c26cbfd 100644 --- a/server/llm_engine_server/infra/repositories/model_endpoint_cache_repository.py +++ b/model-engine/model_engine_server/infra/repositories/model_endpoint_cache_repository.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from typing import Optional -from llm_engine_server.domain.entities import ModelEndpointInfraState +from model_engine_server.domain.entities import ModelEndpointInfraState class ModelEndpointCacheRepository(ABC): diff --git a/server/llm_engine_server/infra/repositories/model_endpoint_record_repository.py b/model-engine/model_engine_server/infra/repositories/model_endpoint_record_repository.py similarity index 97% rename from server/llm_engine_server/infra/repositories/model_endpoint_record_repository.py rename to model-engine/model_engine_server/infra/repositories/model_endpoint_record_repository.py index 48a222b3..3abaee21 100644 --- a/server/llm_engine_server/infra/repositories/model_endpoint_record_repository.py +++ b/model-engine/model_engine_server/infra/repositories/model_endpoint_record_repository.py @@ -2,8 +2,8 @@ from contextlib import AbstractAsyncContextManager from typing import Any, Dict, List, Optional, Sequence -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.domain.entities import ModelEndpointRecord __all__: Sequence[str] = ("ModelEndpointRecordRepository",) diff --git a/server/llm_engine_server/infra/repositories/redis_feature_flag_repository.py b/model-engine/model_engine_server/infra/repositories/redis_feature_flag_repository.py similarity index 92% rename from server/llm_engine_server/infra/repositories/redis_feature_flag_repository.py rename to model-engine/model_engine_server/infra/repositories/redis_feature_flag_repository.py index b40c1c30..8283ab23 100644 --- a/server/llm_engine_server/infra/repositories/redis_feature_flag_repository.py +++ b/model-engine/model_engine_server/infra/repositories/redis_feature_flag_repository.py @@ -1,7 +1,7 @@ from typing import Optional import aioredis -from llm_engine_server.infra.repositories.feature_flag_repository import FeatureFlagRepository +from model_engine_server.infra.repositories.feature_flag_repository import FeatureFlagRepository class RedisFeatureFlagRepository(FeatureFlagRepository): @@ -27,7 +27,7 @@ def __init__( @staticmethod def _to_redis_key(key: str): - return f"llm-engine-feature-flag:{key}" + return f"launch-feature-flag:{key}" async def write_feature_flag_bool(self, key: str, value: bool): if not isinstance(value, bool): diff --git a/server/llm_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py b/model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py similarity index 84% rename from server/llm_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py rename to model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py index 7f2dde7a..feea00b7 100644 --- a/server/llm_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py +++ b/model-engine/model_engine_server/infra/repositories/redis_model_endpoint_cache_repository.py @@ -1,12 +1,15 @@ import json +import os from typing import Optional import aioredis -from llm_engine_server.domain.entities import ModelEndpointInfraState -from llm_engine_server.infra.repositories.model_endpoint_cache_repository import ( +from model_engine_server.domain.entities import ModelEndpointInfraState +from model_engine_server.infra.repositories.model_endpoint_cache_repository import ( ModelEndpointCacheRepository, ) +SERVICE_IDENTIFIER = os.getenv("SERVICE_IDENTIFIER") + class RedisModelEndpointCacheRepository(ModelEndpointCacheRepository): # TODO figure out exceptions that can be thrown @@ -32,7 +35,10 @@ def __init__( @staticmethod def _find_redis_key(key: str): - return f"llm-engine-k8s-cache:{key}" + if SERVICE_IDENTIFIER: + return f"launch-k8s-cache:{SERVICE_IDENTIFIER}:{key}" + else: + return f"launch-k8s-cache:{key}" async def write_endpoint_info( self, diff --git a/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py new file mode 100644 index 00000000..2dfcbc76 --- /dev/null +++ b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_events_repository.py @@ -0,0 +1,89 @@ +import json +import os +from json.decoder import JSONDecodeError +from typing import IO, List + +import boto3 +import smart_open +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneEvent +from model_engine_server.domain.exceptions import ObjectNotFoundException +from model_engine_server.domain.repositories.llm_fine_tune_events_repository import ( + LLMFineTuneEventsRepository, +) + +# Echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py +S3_HF_USER_FINE_TUNED_WEIGHTS_PREFIX = ( + f"s3://{infra_config().s3_bucket}/hosted-model-inference/fine_tuned_weights" +) + + +class S3FileLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): + def __init__(self): + pass + + # _get_s3_client + _open copypasted from s3_file_llm_fine_tune_repo, in turn from s3_filesystem_gateway + # sorry + def _get_s3_client(self, kwargs): + profile_name = kwargs.get("aws_profile", os.getenv("S3_WRITE_AWS_PROFILE")) + session = boto3.Session(profile_name=profile_name) + client = session.client("s3") + return client + + def _open(self, uri: str, mode: str = "rt", **kwargs) -> IO: + # This follows the 5.1.0 smart_open API + client = self._get_s3_client(kwargs) + transport_params = {"client": client} + return smart_open.open(uri, mode, transport_params=transport_params) + + # echoes llm/finetune_pipeline/docker_image_fine_tuning_entrypoint.py + def _get_model_cache_directory_name(self, model_name: str): + """How huggingface maps model names to directory names in their cache for model files. + We adopt this when storing model cache files in s3. + + Args: + model_name (str): Name of the huggingface model + """ + name = "models--" + model_name.replace("/", "--") + return name + + def _get_file_location(self, user_id: str, model_endpoint_name: str): + model_cache_name = self._get_model_cache_directory_name(model_endpoint_name) + s3_file_location = ( + f"{S3_HF_USER_FINE_TUNED_WEIGHTS_PREFIX}/{user_id}/{model_cache_name}.jsonl" + ) + return s3_file_location + + async def get_fine_tune_events( + self, user_id: str, model_endpoint_name: str + ) -> List[LLMFineTuneEvent]: + s3_file_location = self._get_file_location( + user_id=user_id, model_endpoint_name=model_endpoint_name + ) + try: + with self._open(s3_file_location, "r") as f: + lines = f.readlines() + final_events = [] + for line in lines: + try: + event_dict = json.loads(line) + event = LLMFineTuneEvent( + timestamp=event_dict["timestamp"], + message=str(event_dict["message"]), + level=event_dict.get("level", "info"), + ) + except JSONDecodeError: + event = LLMFineTuneEvent( + message=line, + level="info", + ) + final_events.append(event) + return final_events + except Exception as exc: # TODO better exception + raise ObjectNotFoundException from exc + + async def initialize_events(self, user_id: str, model_endpoint_name: str) -> None: + s3_file_location = self._get_file_location( + user_id=user_id, model_endpoint_name=model_endpoint_name + ) + self._open(s3_file_location, "w") diff --git a/server/llm_engine_server/infra/repositories/s3_file_llm_fine_tuning_job_repository.py b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py similarity index 81% rename from server/llm_engine_server/infra/repositories/s3_file_llm_fine_tuning_job_repository.py rename to model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py index e651dcd2..6b3ea8aa 100644 --- a/server/llm_engine_server/infra/repositories/s3_file_llm_fine_tuning_job_repository.py +++ b/model-engine/model_engine_server/infra/repositories/s3_file_llm_fine_tune_repository.py @@ -4,13 +4,11 @@ import boto3 import smart_open -from llm_engine_server.domain.entities.llm_fine_tune_job_entity import LLMFineTuneJobTemplate -from llm_engine_server.infra.repositories.llm_fine_tuning_job_repository import ( - LLMFineTuningJobRepository, -) +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository -class S3FileLLMFineTuningJobRepository(LLMFineTuningJobRepository): +class S3FileLLMFineTuneRepository(LLMFineTuneRepository): def __init__(self, file_path: str): self.file_path = file_path @@ -32,7 +30,7 @@ def _get_key(model_name, fine_tuning_method): async def get_job_template_for_model( self, model_name: str, fine_tuning_method: str - ) -> Optional[LLMFineTuneJobTemplate]: + ) -> Optional[LLMFineTuneTemplate]: # can hot reload the file lol with self._open(self.file_path, "r") as f: data = json.load(f) @@ -40,10 +38,10 @@ async def get_job_template_for_model( job_template_dict = data.get(key, None) if job_template_dict is None: return None - return LLMFineTuneJobTemplate.parse_obj(job_template_dict) + return LLMFineTuneTemplate.parse_obj(job_template_dict) async def write_job_template_for_model( - self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneJobTemplate + self, model_name: str, fine_tuning_method: str, job_template: LLMFineTuneTemplate ): # Use locally in script with self._open(self.file_path, "r") as f: diff --git a/server/llm_engine_server/infra/services/__init__.py b/model-engine/model_engine_server/infra/services/__init__.py similarity index 100% rename from server/llm_engine_server/infra/services/__init__.py rename to model-engine/model_engine_server/infra/services/__init__.py diff --git a/server/llm_engine_server/infra/services/batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/batch_job_orchestration_service.py similarity index 91% rename from server/llm_engine_server/infra/services/batch_job_orchestration_service.py rename to model-engine/model_engine_server/infra/services/batch_job_orchestration_service.py index bba6d661..bbfa54af 100644 --- a/server/llm_engine_server/infra/services/batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/batch_job_orchestration_service.py @@ -1,7 +1,7 @@ from abc import ABC, abstractmethod from datetime import timedelta -from llm_engine_server.domain.entities import BatchJobSerializationFormat +from model_engine_server.domain.entities import BatchJobSerializationFormat class BatchJobOrchestrationService(ABC): diff --git a/server/llm_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py similarity index 51% rename from server/llm_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py rename to model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py index 7773182f..d35ef21a 100644 --- a/server/llm_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py +++ b/model-engine/model_engine_server/infra/services/docker_image_batch_job_llm_fine_tuning_service.py @@ -1,22 +1,25 @@ import os -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob -from llm_engine_server.domain.exceptions import ( +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import FineTuneHparamValueType +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.domain.exceptions import ( InvalidRequestException, LLMFineTuningMethodNotImplementedException, ) -from llm_engine_server.domain.gateways.docker_image_batch_job_gateway import ( +from model_engine_server.domain.gateways.docker_image_batch_job_gateway import ( DockerImageBatchJobGateway, ) -from llm_engine_server.domain.repositories.docker_image_batch_job_bundle_repository import ( +from model_engine_server.domain.repositories.docker_image_batch_job_bundle_repository import ( DockerImageBatchJobBundleRepository, ) -from llm_engine_server.domain.services.llm_fine_tuning_service import LLMFineTuningService -from llm_engine_server.infra.repositories.llm_fine_tuning_job_repository import ( - LLMFineTuningJobRepository, -) +from model_engine_server.domain.services import LLMFineTuningService +from model_engine_server.infra.repositories.llm_fine_tune_repository import LLMFineTuneRepository + +logger = make_logger(logger_name()) class DockerImageBatchJobLLMFineTuningService(LLMFineTuningService): @@ -24,30 +27,33 @@ def __init__( self, docker_image_batch_job_gateway: DockerImageBatchJobGateway, docker_image_batch_job_bundle_repo: DockerImageBatchJobBundleRepository, - llm_fine_tuning_job_repository: LLMFineTuningJobRepository, + llm_fine_tune_repository: LLMFineTuneRepository, ): self.docker_image_batch_job_gateway = docker_image_batch_job_gateway self.docker_image_batch_job_bundle_repo = docker_image_batch_job_bundle_repo - self.llm_fine_tuning_job_repository = llm_fine_tuning_job_repository + self.llm_fine_tune_repository = llm_fine_tune_repository - async def create_fine_tune_job( + async def create_fine_tune( self, created_by: str, owner: str, + model: str, training_file: str, - validation_file: str, - model_name: str, - base_model: str, + validation_file: Optional[str], fine_tuning_method: str, - hyperparameters: Dict[str, str], + hyperparameters: Dict[str, FineTuneHparamValueType], + fine_tuned_model: str, + wandb_config: Optional[Dict[str, Any]], ) -> str: - batch_job_template = await self.llm_fine_tuning_job_repository.get_job_template_for_model( - model_name=base_model, fine_tuning_method=fine_tuning_method + # fine_tuned_model must be a valid k8s label. Leaky implementation detail unfortunately. + batch_job_template = await self.llm_fine_tune_repository.get_job_template_for_model( + model_name=model, fine_tuning_method=fine_tuning_method ) if batch_job_template is None: raise LLMFineTuningMethodNotImplementedException( - f"Fine-tuning not implemented for the (base model, fine-tuning method) pairing of ({base_model}, {fine_tuning_method})" - ) + f"Fine-tuning not implemented for model type {model}" + # f"Fine-tuning not implemented for the (base model, fine-tuning method) pairing of ({base_model}, {fine_tuning_method})" + ) # TODO uncomment out error when we support multiple fine tuning methods for param in batch_job_template.required_params: if param not in hyperparameters: @@ -66,23 +72,36 @@ async def create_fine_tune_job( ) if di_batch_job_bundle is None: - raise LLMFineTuningMethodNotImplementedException("Fine-tuning job doesn't exist") + raise LLMFineTuningMethodNotImplementedException("Fine-tuning method doesn't exist") if not di_batch_job_bundle.public and di_batch_job_bundle.owner != owner: - raise LLMFineTuningMethodNotImplementedException("Fine-tuning job not accessible") + raise LLMFineTuningMethodNotImplementedException("Fine-tuning method not accessible") + # TODO: Pass user-defined labels + labels = dict(team="egp", product="training.llm_engine_fine_tune") + + logger.info( + f"Using bundle {di_batch_job_bundle.id} for fine-tune job: {di_batch_job_bundle.image_repository=}, {di_batch_job_bundle.image_tag=}" + ) batch_job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( created_by=created_by, owner=owner, job_config=dict( + **labels, gateway_url=os.getenv("GATEWAY_URL"), + cloud_provider=infra_config().cloud_provider, + aws_profile=infra_config().profile_ml_worker, + s3_bucket=infra_config().s3_bucket, + azure_client_id=os.getenv("AZURE_CLIENT_ID"), + abs_account_name=os.getenv("ABS_ACCOUNT_NAME"), + abs_container_name=os.getenv("ABS_CONTAINER_NAME"), user_id=owner, training_file=training_file, validation_file=validation_file, - model_name=model_name, - launch_bundle_config=batch_job_template.launch_bundle_config, + model_name=fine_tuned_model, launch_endpoint_config=batch_job_template.launch_endpoint_config, hyperparameters=combined_hyperparameters, + wandb_config=wandb_config, ), env=di_batch_job_bundle.env, command=di_batch_job_bundle.command, @@ -95,15 +114,14 @@ async def create_fine_tune_job( gpu_type=di_batch_job_bundle.gpu_type, storage=di_batch_job_bundle.storage, ), - labels=dict(team="infra", product="llm-fine-tuning"), + labels=labels, + annotations=dict(fine_tuned_model=fine_tuned_model), mount_location=di_batch_job_bundle.mount_location, ) return batch_job_id - async def get_fine_tune_job( - self, owner: str, fine_tune_id: str - ) -> Optional[DockerImageBatchJob]: + async def get_fine_tune(self, owner: str, fine_tune_id: str) -> Optional[DockerImageBatchJob]: di_batch_job = await self.docker_image_batch_job_gateway.get_docker_image_batch_job( batch_job_id=fine_tune_id ) @@ -111,17 +129,25 @@ async def get_fine_tune_job( return None return di_batch_job - async def list_fine_tune_jobs(self, owner: str) -> List[DockerImageBatchJob]: + async def list_fine_tunes(self, owner: str) -> List[DockerImageBatchJob]: di_batch_jobs = await self.docker_image_batch_job_gateway.list_docker_image_batch_jobs( owner=owner ) return di_batch_jobs - async def cancel_fine_tune_job(self, owner: str, fine_tune_id: str) -> bool: - di_batch_job = self.get_fine_tune_job(owner, fine_tune_id) + async def cancel_fine_tune(self, owner: str, fine_tune_id: str) -> bool: + di_batch_job = await self.get_fine_tune(owner, fine_tune_id) if di_batch_job is None: return False cancel = await self.docker_image_batch_job_gateway.update_docker_image_batch_job( batch_job_id=fine_tune_id, cancel=True ) return cancel + + async def get_fine_tune_model_name_from_id( + self, owner: str, fine_tune_id: str + ) -> Optional[str]: + di_batch_job = await self.get_fine_tune(owner, fine_tune_id) + if di_batch_job is None or di_batch_job.annotations is None: + return None + return di_batch_job.annotations["fine_tuned_model"] diff --git a/model-engine/model_engine_server/infra/services/fake_llm_batch_completions_service.py b/model-engine/model_engine_server/infra/services/fake_llm_batch_completions_service.py new file mode 100644 index 00000000..bc72fc22 --- /dev/null +++ b/model-engine/model_engine_server/infra/services/fake_llm_batch_completions_service.py @@ -0,0 +1,87 @@ +from typing import Dict, Optional + +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.llms.batch_completion import ( + BatchCompletionsJob, + CreateBatchCompletionsEngineRequest, + UpdateBatchCompletionsV2Request, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.services.llm_batch_completions_service import ( + LLMBatchCompletionsService, +) + + +class FakeLLMBatchCompletionsService(LLMBatchCompletionsService): + def __init__( + self, + ): + self.jobs = [] + + async def create_batch_job( + self, + *, + user: User, + image_repo: str, + image_tag: str, + job_request: CreateBatchCompletionsEngineRequest, + resource_requests: CreateDockerImageBatchJobResourceRequests, + max_runtime_sec: int = 24 * 60 * 60, + labels: Dict[str, str] = {}, + num_workers: Optional[int] = 1, + ) -> BatchCompletionsJob: + """ + Create a batch completion job. + + Args: + owner: The user who requested the batch job + image_repo: The docker repo where the image is stored + image_tag: The tag of the batch completions image + job_config: The user-specified input to the batch job. Exposed as a file mounted at mount_location to the batch job + labels: Labels to apply to the batch job. + resource_requests: The resource requests for the batch job. + max_runtime_sec: The timeout of the batch job in seconds. + num_workers: The number of workers to run in the job. + + Returns: + The ID of the batch job. + """ + raise NotImplementedError() + + async def get_batch_job(self, batch_job_id: str, user: User) -> Optional[BatchCompletionsJob]: + """ + Get a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + The batch job, or None if it does not exist. + """ + raise NotImplementedError() + + async def update_batch_job( + self, batch_job_id: str, request: UpdateBatchCompletionsV2Request, user: User + ) -> Optional[BatchCompletionsJob]: + """ + Get a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + The batch job, or None if it does not exist. + """ + raise NotImplementedError() + + async def cancel_batch_job(self, batch_job_id: str, user: User) -> bool: + """ + Update a batch job. + + Args: + batch_job_id: The ID of the batch job. + + Returns: + Whether the batch job was updated successfully. + """ + return False diff --git a/model-engine/model_engine_server/infra/services/image_cache_service.py b/model-engine/model_engine_server/infra/services/image_cache_service.py new file mode 100644 index 00000000..f2b1dc28 --- /dev/null +++ b/model-engine/model_engine_server/infra/services/image_cache_service.py @@ -0,0 +1,196 @@ +from datetime import datetime +from typing import Dict, NamedTuple, Tuple + +import pytz +from model_engine_server.common.config import hmi_config +from model_engine_server.common.env_vars import CIRCLECI, GIT_TAG +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import GpuType, ModelEndpointInfraState +from model_engine_server.domain.exceptions import DockerRepositoryNotFoundException +from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.infra.gateways.resources.image_cache_gateway import ( + CachedImages, + ImageCacheGateway, +) +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( + ModelEndpointRecordRepository, +) + +logger = make_logger(logger_name()) + +IMAGES_TO_CACHE_PER_INSTANCE_TYPE = 32 + +CachePriority = NamedTuple( + "CachePriority", + ( + ("is_high_priority", int), + ("has_no_available_workers", int), + ("last_updated_at", datetime), + ), +) + +DockerImage = NamedTuple( + "DockerImage", + ( + ("repo", str), + ("tag", str), + ), +) + + +class ImageCacheService: + """ + Represents reading from k8s and writing images to the k8s image cache. + """ + + def __init__( + self, + model_endpoint_record_repository: ModelEndpointRecordRepository, + image_cache_gateway: ImageCacheGateway, + docker_repository: DockerRepository, + ): + self.model_endpoint_record_repository = model_endpoint_record_repository + self.image_cache_gateway = image_cache_gateway + self.docker_repository = docker_repository + + def _cache_finetune_llm_images( + self, images_to_cache_priority: Dict[str, Dict[str, CachePriority]] + ): + """ + Cache images used by fine tune LLM endpoints to reduce cold start time. + """ + # a cache priority to ensure llm endpoint images are always prioritized + llm_image_cache_priority = CachePriority( + is_high_priority=1, # make it a high priority + has_no_available_workers=1, + # assuming it has no available workers so that it will be at top after reverse sorting + last_updated_at=datetime.max.replace(tzinfo=pytz.utc), + # setting it to max to ensure it will be at top after reverse sorting + ) + + istio_image = DockerImage("gcr.io/istio-release/proxyv2", "1.15.0") + tgi_image_110 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "1.1.0" + ) + vllm_image_027 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.2.7" + ) + vllm_image_032 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2" + ) + latest_tag = "fake_docker_repository_latest_image_tag" + if not CIRCLECI: + try: # pragma: no cover + latest_tag = self.docker_repository.get_latest_image_tag( + hmi_config.batch_inference_vllm_repository + ) + except DockerRepositoryNotFoundException: + pass + vllm_batch_image_latest = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}", + latest_tag, + ) + forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/model-engine", GIT_TAG) + + for llm_image in [ + istio_image, + tgi_image_110, + vllm_image_027, + vllm_image_032, + vllm_batch_image_latest, + forwarder_image, + ]: + if self.docker_repository.is_repo_name( + llm_image.repo + ) and not self.docker_repository.image_exists(llm_image.tag, llm_image.repo): + logger.warning( + f"Image {llm_image.repo}:{llm_image.tag} does not exist. Skipping caching ..." + ) + continue + image = f"{llm_image.repo}:{llm_image.tag}" + for key in ["a10", "a100", "h100", "h100_3g40gb", "h100_1g20gb"]: + images_to_cache_priority[key][image] = llm_image_cache_priority + + async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpointInfraState]]): + images_to_cache_priority: Dict[str, Dict[str, CachePriority]] = { + "cpu": {}, + "a10": {}, + "a100": {}, + "t4": {}, + "h100": {}, + "h100_3g40gb": {}, + "h100_1g20gb": {}, + } + + self._cache_finetune_llm_images(images_to_cache_priority) + + for endpoint_id, (_, state) in endpoint_infra_states.items(): + record = await self.model_endpoint_record_repository.get_model_endpoint_record( + endpoint_id + ) + + if record is None: + continue + + last_updated_at = ( + record.last_updated_at.replace(tzinfo=pytz.utc) + if record.last_updated_at is not None + else datetime.min.replace(tzinfo=pytz.utc) + ) + has_no_available_workers = int(state.deployment_state.available_workers == 0) + is_high_priority = int(state.high_priority is True) + + # TODO: Adding for image cache stability and to make it faster. Remove this + # condition when things are proven to run smoothly. + if not state.high_priority: + continue + + cache_priority = CachePriority( + is_high_priority=is_high_priority, + has_no_available_workers=has_no_available_workers, + last_updated_at=last_updated_at, + ) + + image_repository_and_tag = state.image.split("/", 1)[1] + repository_name, image_tag = image_repository_and_tag.split(":") + if state.resource_state.gpus == 0 and ( + ( + state.image not in images_to_cache_priority["cpu"] + or last_updated_at.replace(tzinfo=pytz.utc) + > images_to_cache_priority["cpu"][state.image].last_updated_at.replace( + tzinfo=pytz.utc + ) + ) + and self.docker_repository.image_exists(image_tag, repository_name) + ): + images_to_cache_priority["cpu"][state.image] = cache_priority + elif state.resource_state.gpus > 0: + for gpu_type, key in [ + (GpuType.NVIDIA_AMPERE_A10, "a10"), + (GpuType.NVIDIA_AMPERE_A100, "a100"), + (GpuType.NVIDIA_TESLA_T4, "t4"), + (GpuType.NVIDIA_HOPPER_H100, "h100"), + (GpuType.NVIDIA_HOPPER_H100_3G_40GB, "h100_3g40gb"), + (GpuType.NVIDIA_HOPPER_H100_1G_20GB, "h100_1g20gb"), + ]: + if state.resource_state.gpu_type == gpu_type and ( + ( + state.image not in images_to_cache_priority[key] + or last_updated_at.replace(tzinfo=pytz.utc) + > images_to_cache_priority[key][state.image].last_updated_at.replace( + tzinfo=pytz.utc + ) + ) + and self.docker_repository.image_exists(image_tag, repository_name) + ): + images_to_cache_priority[key][state.image] = cache_priority + images_to_cache = CachedImages( + cpu=[], a10=[], a100=[], t4=[], h100=[], h100_1g20gb=[], h100_3g40gb=[] + ) + for key, val in images_to_cache_priority.items(): + images_to_cache[key] = sorted( # type: ignore + val.keys(), key=lambda image: val[image], reverse=True + )[:IMAGES_TO_CACHE_PER_INSTANCE_TYPE] + + await self.image_cache_gateway.create_or_update_image_cache(images_to_cache) diff --git a/server/llm_engine_server/infra/services/live_batch_job_orchestration_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py similarity index 87% rename from server/llm_engine_server/infra/services/live_batch_job_orchestration_service.py rename to model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py index c847044a..76c6cd38 100644 --- a/server/llm_engine_server/infra/services/live_batch_job_orchestration_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_orchestration_service.py @@ -10,35 +10,36 @@ from datetime import datetime, timedelta from typing import List, Optional, Union -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.tasks import ( EndpointPredictV1Request, GetAsyncTaskV1Response, TaskStatus, ) -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.domain_exceptions import ObjectNotFoundException -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import ( BatchJobProgress, BatchJobRecord, BatchJobSerializationFormat, BatchJobStatus, ModelEndpointStatus, ) -from llm_engine_server.domain.gateways import AsyncModelEndpointInferenceGateway -from llm_engine_server.domain.services import ModelEndpointService -from llm_engine_server.domain.use_cases.async_inference_use_cases import ( +from model_engine_server.domain.exceptions import ObjectNotFoundException +from model_engine_server.domain.gateways import AsyncModelEndpointInferenceGateway +from model_engine_server.domain.services import ModelEndpointService +from model_engine_server.domain.use_cases.async_inference_use_cases import ( DEFAULT_TASK_TIMEOUT_SECONDS, ) -from llm_engine_server.infra.gateways import BatchJobProgressGateway, FilesystemGateway -from llm_engine_server.infra.repositories.batch_job_record_repository import ( +from model_engine_server.infra.gateways import BatchJobProgressGateway +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.repositories.batch_job_record_repository import ( BatchJobRecordRepository, ) -from llm_engine_server.infra.services.batch_job_orchestration_service import ( +from model_engine_server.infra.services.batch_job_orchestration_service import ( BatchJobOrchestrationService, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) @dataclass @@ -165,10 +166,7 @@ async def _run_batch_job( ) results = self._poll_tasks( - owner=owner, - job_id=job_id, - task_ids=task_ids, - timeout_timestamp=timeout_timestamp, + owner=owner, job_id=job_id, task_ids=task_ids, timeout_timestamp=timeout_timestamp ) result_location = batch_job_record.result_location @@ -204,10 +202,7 @@ async def _wait_for_endpoint_to_be_ready( model_endpoint = await self.model_endpoint_service.get_model_endpoint_record( model_endpoint_id=model_endpoint_id, ) - updating = { - ModelEndpointStatus.UPDATE_PENDING, - ModelEndpointStatus.UPDATE_IN_PROGRESS, - } + updating = {ModelEndpointStatus.UPDATE_PENDING, ModelEndpointStatus.UPDATE_IN_PROGRESS} assert model_endpoint while model_endpoint.status in updating: @@ -245,9 +240,7 @@ async def _read_or_submit_tasks( pending_task_ids_location = batch_job_record.task_ids_location if pending_task_ids_location is not None: with self.filesystem_gateway.open( - pending_task_ids_location, - "r", - aws_profile=ml_infra_config().profile_ml_worker, + pending_task_ids_location, "r", aws_profile=infra_config().profile_ml_worker ) as f: task_ids_serialized = f.read().splitlines() task_ids = [ @@ -261,9 +254,7 @@ async def _read_or_submit_tasks( task_ids = await self._submit_tasks(queue_name, input_path, task_name) pending_task_ids_location = self._get_pending_task_ids_location(job_id) with self.filesystem_gateway.open( - pending_task_ids_location, - "w", - aws_profile=ml_infra_config().profile_ml_worker, + pending_task_ids_location, "w", aws_profile=infra_config().profile_ml_worker ) as f: f.write("\n".join([tid.serialize() for tid in task_ids])) await self.batch_job_record_repository.update_batch_job_record( @@ -291,7 +282,7 @@ def _create_task( inputs: List[BatchEndpointInferencePrediction] = [] with self.filesystem_gateway.open( - input_path, "r", aws_profile=ml_infra_config().profile_ml_worker + input_path, "r", aws_profile=infra_config().profile_ml_worker ) as f: # Increase the CSV reader's field limit size from the default (131072) csv.field_size_limit(sys.maxsize) @@ -337,8 +328,7 @@ def _poll_tasks( self.batch_job_progress_gateway.update_progress(owner, job_id, progress) while pending_task_ids_set: new_results = executor.map( - self.async_model_endpoint_inference_gateway.get_task, - pending_task_ids_set, + self.async_model_endpoint_inference_gateway.get_task, pending_task_ids_set ) has_new_ready_tasks = False curr_timestamp = datetime.utcnow() @@ -362,8 +352,7 @@ def _poll_tasks( results = [ BatchEndpointInferencePredictionResponse( - response=task_id_to_result[task_id], - reference_id=task_id_to_ref_id_map[task_id], + response=task_id_to_result[task_id], reference_id=task_id_to_ref_id_map[task_id] ) for task_id in task_ids_only ] @@ -383,14 +372,14 @@ def _serialize_and_write_results( results_serialized = pickle.dumps(results) with self.filesystem_gateway.open( - result_location, "wb", aws_profile=ml_infra_config().profile_ml_worker + result_location, "wb", aws_profile=infra_config().profile_ml_worker ) as f: f.write(results_serialized) @staticmethod def _get_pending_task_ids_location(job_id: str) -> str: - return f"s3://{ml_infra_config().s3_bucket}/llm-engine/batch-jobs/{job_id}/pending_task_ids.txt" + return f"s3://{infra_config().s3_bucket}/launch/batch-jobs/{job_id}/pending_task_ids.txt" @staticmethod def _get_job_result_location(job_id: str) -> str: - return f"s3://{ml_infra_config().s3_bucket}/llm-engine/batch-jobs/{job_id}/result.json" + return f"s3://{infra_config().s3_bucket}/launch/batch-jobs/{job_id}/result.json" diff --git a/server/llm_engine_server/infra/services/live_batch_job_service.py b/model-engine/model_engine_server/infra/services/live_batch_job_service.py similarity index 88% rename from server/llm_engine_server/infra/services/live_batch_job_service.py rename to model-engine/model_engine_server/infra/services/live_batch_job_service.py index 7e49c699..aa5029c7 100644 --- a/server/llm_engine_server/infra/services/live_batch_job_service.py +++ b/model-engine/model_engine_server/infra/services/live_batch_job_service.py @@ -1,8 +1,8 @@ from typing import Dict, Optional -from llm_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import ( BatchJob, BatchJobProgress, BatchJobSerializationFormat, @@ -10,17 +10,18 @@ GpuType, ModelEndpointType, ) -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.domain.services import BatchJobService, ModelEndpointService -from llm_engine_server.infra.gateways import BatchJobOrchestrationGateway, BatchJobProgressGateway -from llm_engine_server.infra.repositories.batch_job_record_repository import ( +from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.domain.services import BatchJobService, ModelEndpointService +from model_engine_server.infra.gateways import BatchJobOrchestrationGateway, BatchJobProgressGateway +from model_engine_server.infra.repositories.batch_job_record_repository import ( BatchJobRecordRepository, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) DEFAULT_ENDPOINT_CPUS_BATCH_JOB = 3 DEFAULT_ENDPOINT_MEMORY_BATCH_JOB = "12Gi" +DEFAULT_ENDPOINT_STORAGE_BATCH_JOB = "16Gi" # to match launch-python-client endpoint default DEFAULT_ENDPOINT_GPUS_BATCH_JOB = 1 DEFAULT_ENDPOINT_GPU_TYPE_BATCH_JOB = GpuType.NVIDIA_TESLA_T4 DEFAULT_ENDPOINT_MAX_WORKERS_BATCH_JOB = 50 @@ -76,6 +77,7 @@ async def create_batch_job( else DEFAULT_ENDPOINT_GPUS_BATCH_JOB ) memory = resource_requests.memory or DEFAULT_ENDPOINT_MEMORY_BATCH_JOB + storage = resource_requests.storage or DEFAULT_ENDPOINT_STORAGE_BATCH_JOB gpu_type = None if gpus == 0 and resource_requests.gpu_type is not None: raise EndpointResourceInvalidRequestException( @@ -101,7 +103,8 @@ async def create_batch_job( gpus=gpus, # type: ignore memory=memory, # type: ignore gpu_type=gpu_type, # type: ignore - storage=resource_requests.storage, + storage=storage, + nodes_per_worker=1, # TODO batch jobs currently doesn't support multinode, since async multinode isn't supported yet optimize_costs=False, min_workers=0, max_workers=max_workers, # type: ignore diff --git a/server/llm_engine_server/infra/services/live_endpoint_builder_service.py b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py similarity index 77% rename from server/llm_engine_server/infra/services/live_endpoint_builder_service.py rename to model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py index 35b93e28..ae3836c6 100644 --- a/server/llm_engine_server/infra/services/live_endpoint_builder_service.py +++ b/model-engine/model_engine_server/infra/services/live_endpoint_builder_service.py @@ -2,33 +2,31 @@ import json import os import tempfile +import time from contextlib import AsyncExitStack from logging import LoggerAdapter -from typing import List, Optional, Sequence +from typing import Dict, List, Optional, Sequence, Set from datadog import statsd -from llm_engine_server.common.constants import ( - FEATURE_FLAG_USE_MULTI_CONTAINER_ARCHITECTURE_FOR_ARTIFACTLIKE_BUNDLE, -) -from llm_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse -from llm_engine_server.common.dtos.endpoint_builder import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, BuildEndpointResponse, BuildEndpointStatus, ) -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.common.env_vars import LOCAL -from llm_engine_server.common.io import open_wrapper -from llm_engine_server.common.serialization_utils import bool_to_str -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.domain_exceptions import DockerBuildFailedException -from llm_engine_server.core.loggers import make_logger -from llm_engine_server.core.notification_gateway import NotificationApp, NotificationGateway -from llm_engine_server.core.utils.env import environment -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.env_vars import LOCAL +from model_engine_server.common.io import open_wrapper +from model_engine_server.common.serialization_utils import bool_to_str +from model_engine_server.core.config import infra_config +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.core.notification_gateway import NotificationApp, NotificationGateway +from model_engine_server.core.utils.env import environment +from model_engine_server.domain.entities import ( ArtifactLike, - CloudpickleArtifactFlavor, CustomFramework, + ModelBundle, ModelBundleFlavorType, ModelEndpointConfig, ModelEndpointDeploymentState, @@ -42,28 +40,34 @@ TensorflowFramework, ZipArtifactFlavor, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.domain.gateways import MonitoringMetricsGateway -from llm_engine_server.domain.repositories import DockerRepository -from llm_engine_server.domain.services import EndpointBuilderService -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.exceptions import ( + DockerBuildFailedException, + EndpointResourceInfraException, +) +from model_engine_server.domain.gateways import MonitoringMetricsGateway +from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.domain.services import EndpointBuilderService +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( CONVERTED_FROM_ARTIFACT_LIKE_KEY, ) -from llm_engine_server.infra.gateways import FilesystemGateway -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from llm_engine_server.infra.infra_utils import make_exception_log -from llm_engine_server.infra.repositories import FeatureFlagRepository, ModelEndpointCacheRepository -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.infra.infra_utils import make_exception_log +from model_engine_server.infra.repositories import ( + FeatureFlagRepository, + ModelEndpointCacheRepository, +) +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) if LOCAL: with environment(KUBERNETES_SERVICE_HOST=None): - logger = make_logger("llm_engine_server.service_builder") + logger = make_logger(logger_name()) else: - logger = make_logger("llm_engine_server.service_builder") + logger = make_logger(logger_name()) __all__: Sequence[str] = ( "INITIAL_K8S_CACHE_TTL_SECONDS", @@ -74,10 +78,37 @@ ECR_AWS_PROFILE: str = os.getenv("ECR_READ_AWS_PROFILE", "default") # type: ignore GIT_TAG: str = os.getenv("GIT_TAG") # type: ignore ENV: str = os.getenv("DD_ENV") # type: ignore +WORKSPACE_PATH = os.getenv("WORKSPACE", ".") INITIAL_K8S_CACHE_TTL_SECONDS: int = 60 MAX_IMAGE_TAG_LEN = 128 +RESTRICTED_ENV_VARS_KEYS = { + "BASE": [ + "DD_TRACE_ENABLED", + "DD_AGENT_HOST", + "DD_ENV", + "DD_SERVICE", + "DD_VERSION", + "OMP_THREAD_LIMIT", + ], + "TRITON": [ + "AWS_PROFILE", + ], + "CELERY": [ + "CELERY_ELASTICACHE_ENABLED", + "CELERY_QUEUE", + "CELERY_TASK_VISIBILITY", + "S3_BUCKET", + ], + "TEMPORAL": [ + "TEMPORAL_TASK_QUEUE", + ], + "HTTP": [ + "HTTP_PORT", + ], +} + class LiveEndpointBuilderService(EndpointBuilderService): def __init__( @@ -103,6 +134,7 @@ def __init__( async def build_endpoint( self, build_endpoint_request: BuildEndpointRequest ) -> BuildEndpointResponse: + time_build_endpoint_start = time.time() self.monitoring_metrics_gateway.emit_attempted_build_metric() logger_extra = build_endpoint_request.dict() @@ -118,12 +150,6 @@ async def build_endpoint( self._validate_build_endpoint_request(build_endpoint_request) - use_multi_container_architecture_for_artifactlike_bundle = ( - await self.feature_flag_repo.read_feature_flag_bool( - FEATURE_FLAG_USE_MULTI_CONTAINER_ARCHITECTURE_FOR_ARTIFACTLIKE_BUNDLE - ) - ) - async with AsyncExitStack() as stack: lock_ctx = self.model_endpoint_record_repository.get_lock_context(model_endpoint_record) lock = await stack.enter_async_context(lock_ctx) @@ -139,29 +165,32 @@ async def build_endpoint( try: # First, build the image if the model bundle does not have a docker image if not model_bundle.is_runnable(): - if use_multi_container_architecture_for_artifactlike_bundle: - assert isinstance( - model_bundle.flavor, CloudpickleArtifactFlavor - ) or isinstance(model_bundle.flavor, ZipArtifactFlavor) - logger_adapter.info( - f"Create a new runnable image model bundle for artifact flavor model bundle {model_bundle.id=} ..." - ) + logger_adapter.info( + f"Create a new runnable image model bundle for artifact flavor model bundle {model_bundle.id=} ..." + ) logger_adapter.info("Building base & user image...") # Build service image in two steps for better caching. # First we build a base image, which is expected to be shared between # many different bundles. try: - base_image_params = self._get_base_image_params( + base_image_params = self.get_base_image_params( build_endpoint_request, logger_adapter ) + logger_adapter.info(f"base_image_params: {base_image_params}") base_image = await self._build_image( - base_image_params, build_endpoint_request, logger_adapter + base_image_params, + build_endpoint_request, + logger_adapter, + "base", ) user_image_params = self._get_user_image_params( base_image, build_endpoint_request, logger_adapter ) image = await self._build_image( - user_image_params, build_endpoint_request, logger_adapter + user_image_params, + build_endpoint_request, + logger_adapter, + "user", ) image_repo = user_image_params.repo @@ -189,43 +218,40 @@ async def build_endpoint( inject_bundle_image_params, build_endpoint_request, logger_adapter, + "inject_bundle", ) # Now that it's no longer needed, clean up serialized bundle file to save storage - model_bundle_path = inject_bundle_image_params.substitution_args[ # type: ignore + model_bundle_path = inject_bundle_image_params.substitution_args[ + # type: ignore "LOCAL_BUNDLE_PATH" ] if os.path.exists(model_bundle_path): os.remove(model_bundle_path) else: - logger.error(f"No bundle object found at {model_bundle_path}!") + logger_adapter.error( + f"No bundle object found at {model_bundle_path}!" + ) except DockerBuildFailedException: log_error("Failed to build base and user docker images") self.monitoring_metrics_gateway.emit_docker_failed_build_metric() raise - if use_multi_container_architecture_for_artifactlike_bundle: - self.convert_artifact_like_bundle_to_runnable_image( - build_endpoint_request, image_repo, image_tag - ) + self.convert_artifact_like_bundle_to_runnable_image( + build_endpoint_request, image_repo, image_tag + ) - # CONVERTED_FROM_ARTIFACT_LIKE_KEY will be checked by `get_endpoint_resource_arguments_from_request()` in k8s_resource_types.py - if not model_endpoint_record.metadata: - model_endpoint_record.metadata = {} - model_endpoint_record.metadata.update( - {CONVERTED_FROM_ARTIFACT_LIKE_KEY: True} - ) - await self.model_endpoint_record_repository.update_model_endpoint_record( - model_endpoint_id=endpoint_id, - metadata=model_endpoint_record.metadata, - ) + # CONVERTED_FROM_ARTIFACT_LIKE_KEY will be checked by `get_endpoint_resource_arguments_from_request()` in k8s_resource_types.py + if not model_endpoint_record.metadata: + model_endpoint_record.metadata = {} + model_endpoint_record.metadata.update({CONVERTED_FROM_ARTIFACT_LIKE_KEY: True}) else: flavor = model_bundle.flavor assert isinstance(flavor, RunnableImageLike) repository = ( - f"{ml_infra_config().docker_repo_prefix}/{flavor.repository}" + f"{infra_config().docker_repo_prefix}/{flavor.repository}" if self.docker_repository.is_repo_name(flavor.repository) else flavor.repository ) @@ -251,6 +277,13 @@ async def build_endpoint( except EndpointResourceInfraException: log_error("K8s resource update failed") raise + finally: + # Clean up CONVERTED_FROM_ARTIFACT_LIKE_KEY as it is for internal use only + if ( + model_endpoint_record.metadata is not None + and CONVERTED_FROM_ARTIFACT_LIKE_KEY in model_endpoint_record.metadata + ): + del model_endpoint_record.metadata[CONVERTED_FROM_ARTIFACT_LIKE_KEY] endpoint_info = ModelEndpointInfraState( deployment_name=build_endpoint_request.deployment_name, @@ -269,11 +302,13 @@ async def build_endpoint( memory=build_endpoint_request.memory, gpu_type=build_endpoint_request.gpu_type, storage=build_endpoint_request.storage, + nodes_per_worker=build_endpoint_request.nodes_per_worker, optimize_costs=build_endpoint_request.optimize_costs, ), user_config_state=ModelEndpointUserConfigState( app_config=build_endpoint_request.model_endpoint_record.current_model_bundle.app_config, endpoint_config=ModelEndpointConfig( + endpoint_type=build_endpoint_request.model_endpoint_record.endpoint_type, endpoint_name=build_endpoint_request.model_endpoint_record.name, bundle_name=build_endpoint_request.model_endpoint_record.current_model_bundle.name, post_inference_hooks=build_endpoint_request.post_inference_hooks, @@ -334,6 +369,13 @@ async def build_endpoint( except Exception: # noqa log_error(f"[Continuing] Failed to emit successful build metric for {endpoint_id=}") + try: + self.monitoring_metrics_gateway.emit_build_time_metric( + time.time() - time_build_endpoint_start + ) + except Exception: # noqa + log_error(f"[Continuing] Failed to emit endpoint build time metric for {endpoint_id=}") + return BuildEndpointResponse(status=BuildEndpointStatus.OK) def convert_artifact_like_bundle_to_runnable_image( @@ -343,7 +385,7 @@ def convert_artifact_like_bundle_to_runnable_image( image_tag: str, ) -> None: """ - With LLMEngine Inference Re-Architecture, we want to deploy endpoints with ArtifactLike bundle using + With Launch Inference Re-Architecture, we want to deploy endpoints with ArtifactLike bundle using multi-container architecture, which RunnableImageFlavor has already adopted. This function mutates the build_endpoint_request by converting the ArtifactLike bundle flavor into @@ -357,29 +399,30 @@ def convert_artifact_like_bundle_to_runnable_image( assert isinstance(model_bundle.flavor, ArtifactLike) new_model_bundle = model_bundle.copy() - if ml_infra_config().env == "circleci": - ml_infra_service_config_file = "config.yaml" + if infra_config().env == "circleci": + infra_config_file = "config.yaml" else: - ml_infra_service_config_file = ml_infra_config().env + ".yaml" + infra_config_file = infra_config().env + ".yaml" new_flavor = RunnableImageFlavor( flavor=ModelBundleFlavorType.RUNNABLE_IMAGE, repository=image_repo, tag=image_tag, + readiness_initial_delay_seconds=30, command=[ "dumb-init", "--", "ddtrace-run", "python", "-m", - "llm_engine_server.inference.sync_inference.start_fastapi_server", + "model_engine_server.inference.sync_inference.start_fastapi_server", ], env={ - "OMP_NUM_THREADS": '"1"', + "OMP_NUM_THREADS": "1", "BASE_PATH": "/app", "BUNDLE_URL": model_bundle.flavor.location, "AWS_PROFILE": build_endpoint_request.aws_role, - "RESULTS_S3_BUCKET": ml_infra_config().s3_bucket, + "RESULTS_S3_BUCKET": infra_config().s3_bucket, "CHILD_FN_INFO": json.dumps( build_endpoint_request.child_fn_info if build_endpoint_request.child_fn_info @@ -387,7 +430,7 @@ def convert_artifact_like_bundle_to_runnable_image( ), "PREWARM": bool_to_str(build_endpoint_request.prewarm) or "false", "PORT": "5005", - "ML_INFRA_SERVICES_CONFIG_PATH": f"/app/ml_infra_core/llm_engine_server.core/llm_engine_server.core/configs/{ml_infra_service_config_file}", + "ML_INFRA_SERVICES_CONFIG_PATH": f"/app/model-engine/model_engine_server/core/configs/{infra_config_file}", }, protocol="http", ) @@ -395,19 +438,19 @@ def convert_artifact_like_bundle_to_runnable_image( if isinstance(model_bundle.flavor, ZipArtifactFlavor): if new_flavor.env is None: new_flavor.env = {} - new_flavor.env[ - "LOAD_PREDICT_FN_MODULE_PATH" - ] = model_bundle.flavor.load_predict_fn_module_path - new_flavor.env[ - "LOAD_MODEL_FN_MODULE_PATH" - ] = model_bundle.flavor.load_model_fn_module_path + new_flavor.env["LOAD_PREDICT_FN_MODULE_PATH"] = ( + model_bundle.flavor.load_predict_fn_module_path + ) + new_flavor.env["LOAD_MODEL_FN_MODULE_PATH"] = ( + model_bundle.flavor.load_model_fn_module_path + ) new_model_bundle.flavor = new_flavor new_model_bundle.model_artifact_ids = [] build_endpoint_request.model_endpoint_record.current_model_bundle = new_model_bundle - def _get_base_image_params( + def get_base_image_params( self, build_endpoint_request: BuildEndpointRequest, logger_adapter: LoggerAdapter, @@ -453,14 +496,15 @@ def _get_base_image_params( raise ValueError(f"Unsupported framework_type: {env_params.framework_type}") # The context should be whatever WORKDIR is in the container running the build app itself. - inference_folder = "llm_engine/llm_engine/inference" - base_path: str = os.getenv("WORKSPACE") # type: ignore + inference_folder = "model-engine/model_engine_server/inference" + logger_adapter.info(f"inference_folder: {inference_folder}") + logger_adapter.info(f"dockerfile: {inference_folder}/{dockerfile}") return BuildImageRequest( - repo="llm-engine", + repo=hmi_config.user_inference_base_repository, image_tag=resulting_image_tag[:MAX_IMAGE_TAG_LEN], aws_profile=ECR_AWS_PROFILE, # type: ignore - base_path=base_path, + base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, requirements_folder=None, @@ -488,7 +532,7 @@ def _get_user_image_params( dockerfile = "pytorch_or_tf.user.Dockerfile" service_image_tag = self._get_image_tag(base_image_tag, GIT_TAG, requirements_hash) - ecr_repo = "hosted-model-inference/async-pytorch" + ecr_repo = hmi_config.user_inference_pytorch_repository elif isinstance(env_params, TensorflowFramework): if build_endpoint_request.gpus > 0: raise NotImplementedError("Tensorflow GPU image not supported yet") @@ -500,7 +544,7 @@ def _get_user_image_params( raise ValueError("Tensorflow version must be specified if the framework is TF.") dockerfile = "pytorch_or_tf.user.Dockerfile" service_image_tag = self._get_image_tag(tensorflow_version, GIT_TAG, requirements_hash) - ecr_repo = "hosted-model-inference/async-tensorflow-cpu" + ecr_repo = hmi_config.user_inference_tensorflow_repository elif isinstance(env_params, CustomFramework): if ( env_params.image_tag is None or env_params.image_repository is None @@ -514,10 +558,8 @@ def _get_user_image_params( raise ValueError(f"Unsupported framework_type: {env_params.framework_type}") # The context should be whatever WORKDIR is in the container running the build app itself. - inference_folder = "llm_engine/llm_engine/inference" - base_path: str = os.getenv("WORKSPACE") # type: ignore - - requirements_folder = os.path.join(base_path, f"requirements_{requirements_hash}") + inference_folder = "model-engine/model_engine_server/inference" + requirements_folder = os.path.join(WORKSPACE_PATH, f"requirements_{requirements_hash}") try: os.mkdir(requirements_folder) except FileExistsError: @@ -535,7 +577,7 @@ def _get_user_image_params( repo=ecr_repo, image_tag=service_image_tag[:MAX_IMAGE_TAG_LEN], aws_profile=ECR_AWS_PROFILE, - base_path=base_path, + base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, requirements_folder=requirements_folder, @@ -551,9 +593,11 @@ def _get_inject_bundle_image_params( ) -> BuildImageRequest: model_endpoint_record = build_endpoint_request.model_endpoint_record model_bundle = model_endpoint_record.current_model_bundle + assert isinstance(model_bundle.flavor, ZipArtifactFlavor) bundle_id = model_bundle.id service_image_str = "-".join([base_image_params.image_tag, GIT_TAG, bundle_id]) + # nosemgrep service_image_hash = hashlib.md5(str(service_image_str).encode("utf-8")).hexdigest() service_image_tag = f"inject-bundle-image-{service_image_hash}" ecr_repo = base_image_params.repo @@ -564,17 +608,15 @@ def _get_inject_bundle_image_params( # The context should be whatever WORKDIR is in the container running the build app itself. dockerfile = "inject_bundle.Dockerfile" - inference_folder = "llm_engine/llm_engine/inference" - base_path: str = os.getenv("WORKSPACE") # type: ignore - - bundle_folder = os.path.join(base_path, f"bundle_{service_image_hash}") + inference_folder = "model-engine/model_engine_server/inference" + bundle_folder = os.path.join(WORKSPACE_PATH, f"bundle_{service_image_hash}") try: os.mkdir(bundle_folder) except FileExistsError: pass _, model_bundle_path = tempfile.mkstemp(dir=bundle_folder, suffix=".zip") bundle_url = model_bundle.location - logger.info( + logger_adapter.info( f"Downloading bundle from serialized object at location {bundle_url} to local path {model_bundle_path}" ) with open_wrapper(bundle_url, "rb") as bundle_data: # type: ignore @@ -584,14 +626,14 @@ def _get_inject_bundle_image_params( substitution_args = { "LOCAL_BUNDLE_PATH": model_bundle_path, "LOAD_MODEL_MODULE_PATH": model_bundle.flavor.load_model_fn_module_path, # type: ignore - "LOAD_PREDICT_MODULE_PATH": model_bundle.flavor.load_predict_fn_module_path, # type: ignore + "LOAD_PREDICT_MODULE_PATH": model_bundle.flavor.load_predict_fn_module_path, } return BuildImageRequest( repo=ecr_repo, image_tag=service_image_tag[:MAX_IMAGE_TAG_LEN], aws_profile=ECR_AWS_PROFILE, - base_path=base_path, + base_path=WORKSPACE_PATH, dockerfile=f"{inference_folder}/{dockerfile}", base_image=base_image, requirements_folder=bundle_folder, @@ -603,6 +645,7 @@ async def _build_image( image_params: BuildImageRequest, build_endpoint_request: BuildEndpointRequest, logger_adapter: LoggerAdapter, + image_type: str, ) -> str: """ Builds the service image and updates the endpoint status if the image building fails. @@ -625,22 +668,24 @@ async def _build_image( image_tag=image_params.image_tag, aws_profile=ECR_AWS_PROFILE, ): + self.monitoring_metrics_gateway.emit_image_build_cache_miss_metric(image_type) tags = [ f"kube_deployment:{build_endpoint_request.deployment_name}", f"user_id:{user_id}", ] - with statsd.timed("kaniko.build_time", tags=tags): + with statsd.timed(f"kaniko.{image_type}_build_time", tags=tags): try: build_result: BuildImageResponse = self.docker_repository.build_image( image_params, ) build_result_status = build_result.status build_result_logs: str = build_result.logs + logger_adapter.info(f"Image Build job: {build_result.job_name}") except Exception: # noqa build_result_status = False s3_logs_location: Optional[str] = None log_error( - "Unknown error encountered on image build" + "Unknown error encountered on image build. " f"No logs to write for {model_endpoint_name}, since docker build threw exception" ) else: @@ -653,7 +698,7 @@ async def _build_image( with self.filesystem_gateway.open( s3_logs_location, "w", - aws_profile=ml_infra_config().profile_ml_worker, + aws_profile=infra_config().profile_ml_worker, ) as file_out: file_out.write(build_result_logs) except Exception: # noqa @@ -682,11 +727,11 @@ async def _build_image( help_url = self.filesystem_gateway.generate_signed_url( s3_logs_location, expiration=43200, # 12 hours - aws_profile=ml_infra_config().profile_ml_worker, + aws_profile=infra_config().profile_ml_worker, ) else: help_url = ( - "https://app.datadoghq.com/logs?query=service%3Allm-engine-" + "https://app.datadoghq.com/logs?query=service%3Alaunch-" f"endpoint-builder%20env%3A{ENV}&cols=host%2Cservice&" "index=%2A&messageDisplay=inline&stream_sort=time%2C" "desc&viz=stream&live=true" @@ -702,7 +747,7 @@ async def _build_image( ) self.notification_gateway.send_notification( - title="LLMEngine Endpoint Build Failed", + title="Launch Endpoint Build Failed", description=message, help_url=help_url, notification_apps=[ @@ -715,9 +760,9 @@ async def _build_image( raise DockerBuildFailedException(f"Image build failed ({endpoint_id=})") else: + self.monitoring_metrics_gateway.emit_image_build_cache_hit_metric(image_type) logger_adapter.info( - f"Image {image_params.repo}:{image_params.image_tag} already exists, " - f"skipping build for {endpoint_id=}" + f"Image already exists, skipping build. Image={image_params.repo}:{image_params.image_tag}, {endpoint_id=}" ) return self.docker_repository.get_image_url(image_params.image_tag, image_params.repo) @@ -728,8 +773,8 @@ def _validate_build_endpoint_request( ) -> None: """Raises ValueError if the request's AWS role isn't allowed.""" allowed_aws_roles = { - ml_infra_config().profile_ml_worker, - ml_infra_config().profile_ml_inference_worker, + infra_config().profile_ml_worker, + infra_config().profile_ml_inference_worker, } if build_endpoint_request.aws_role not in allowed_aws_roles: @@ -738,9 +783,34 @@ def _validate_build_endpoint_request( f"{allowed_aws_roles}." ) + model_bundle: ModelBundle = ( + build_endpoint_request.model_endpoint_record.current_model_bundle + ) + if isinstance(model_bundle.flavor, RunnableImageLike) and model_bundle.flavor.env: + restriced_env_vars = LiveEndpointBuilderService._get_restricted_env_vars( + model_bundle.flavor.env + ) + if len(restriced_env_vars) > 0: + raise ValueError( + f"Runnable image endpoints cannot set the following env vars: {restriced_env_vars}" + ) + if ( + not isinstance(model_bundle.flavor, RunnableImageLike) + and build_endpoint_request.nodes_per_worker > 1 + ): + raise ValueError( + "Multi-node deployment is only supported for RunnableImageLike model bundles." + ) + + @staticmethod + def _get_restricted_env_vars(env_vars: Dict[str, str]) -> Set[str]: + restricted_env_vars = set(key for keys in RESTRICTED_ENV_VARS_KEYS.values() for key in keys) + return set(env_vars.keys()) & restricted_env_vars + @staticmethod def _get_requirements_hash(requirements: List[str]) -> str: """Identifying hash for endpoint's Python requirements.""" + # nosemgrep return hashlib.md5("\n".join(sorted(requirements)).encode("utf-8")).hexdigest()[:6] @staticmethod @@ -757,4 +827,4 @@ def _get_service_builder_logs_location(user_id: str, endpoint_name: str) -> str: This function uses creates a key from the endpoint's name and owning user's ID. It uses an S3 bucket that is accessible by the Gateway & Service Builder. """ - return f"s3://{ml_infra_config().s3_bucket}/service_builder_logs/{user_id}_{endpoint_name}" + return f"s3://{infra_config().s3_bucket}/service_builder_logs/{user_id}_{endpoint_name}" diff --git a/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py b/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py new file mode 100644 index 00000000..ad792365 --- /dev/null +++ b/model-engine/model_engine_server/infra/services/live_llm_batch_completions_service.py @@ -0,0 +1,181 @@ +from dataclasses import dataclass +from datetime import datetime, timedelta +from typing import Dict, Optional + +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.llms import ( + BatchCompletionsJob, + BatchCompletionsJobStatus, + CreateBatchCompletionsEngineRequest, +) +from model_engine_server.common.dtos.llms.batch_completion import ( + UpdateBatchCompletionsV2Request, + UpdateBatchCompletionsV2Response, +) +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.entities.batch_job_entity import BatchJobStatus +from model_engine_server.domain.gateways.docker_image_batch_job_gateway import ( + DockerImageBatchJobGateway, +) +from model_engine_server.domain.services.llm_batch_completions_service import ( + LLMBatchCompletionsService, +) + + +def to_dto(status: BatchJobStatus) -> BatchCompletionsJobStatus: + if status == BatchJobStatus.PENDING: + return BatchCompletionsJobStatus.Queued + if status == BatchJobStatus.RUNNING: + return BatchCompletionsJobStatus.Running + if status == BatchJobStatus.FAILURE: + return BatchCompletionsJobStatus.Failed + if status == BatchJobStatus.SUCCESS: + return BatchCompletionsJobStatus.Completed + if status == BatchJobStatus.CANCELLED: + return BatchCompletionsJobStatus.Cancelled + if status == BatchJobStatus.TIMEOUT: + return BatchCompletionsJobStatus.Failed + + return BatchCompletionsJobStatus.Unknown + + +@dataclass +class CustomJobMetadata: + """ + This is a workaround to the current DockerImageBatchJobGateway implementation + which doesn't store additional metadata we need for batch completions v2 + """ + + input_data_path: Optional[str] + output_data_path: str + expires_at: str + priority: Optional[str] + labels: Dict[str, str] + + +NULL_TOKEN = "null" + + +class LiveLLMBatchCompletionsService(LLMBatchCompletionsService): + def __init__( + self, + docker_image_batch_job_gateway: DockerImageBatchJobGateway, + ): + self.docker_image_batch_job_gateway = docker_image_batch_job_gateway + + def encode_metadata(self, metadata: CustomJobMetadata) -> Dict[str, str]: + return { + "__INT_input_data_path": metadata.input_data_path or NULL_TOKEN, + "__INT_output_data_path": metadata.output_data_path, + "__INT_expires_at": metadata.expires_at, + "__INT_priority": metadata.priority or NULL_TOKEN, + **{f"__LABEL_{key}": value for key, value in metadata.labels.items()}, + } + + def decode_metadata(self, metadata: Dict[str, str]) -> CustomJobMetadata: + labels = { + key.replace("__LABEL_", ""): value + for key, value in metadata.items() + if key.startswith("__LABEL") + } + + return CustomJobMetadata( + input_data_path=metadata.get("__INT_input_data_path", "unknown"), + output_data_path=metadata.get("__INT_output_data_path", "unknown"), + expires_at=metadata.get("__INT_expires_at", "unknown"), + priority=metadata.get("__INT_priority", "unknown"), + labels=labels, + ) + + async def create_batch_job( + self, + *, + user: User, + image_repo: str, + image_tag: str, + job_request: CreateBatchCompletionsEngineRequest, + resource_requests: CreateDockerImageBatchJobResourceRequests, + max_runtime_sec: int = 24 * 60 * 60, + labels: Dict[str, str] = {}, + num_workers: Optional[int] = 1, + ): + config_file_path = "/opt/config.json" + env = {"CONFIG_FILE": config_file_path} + command = [ + "dumb-init", + "--", + "/bin/bash", + "-c", + "ddtrace-run python vllm_batch.py", + ] + + expires_at = datetime.now() + timedelta(seconds=max_runtime_sec) + job_id = await self.docker_image_batch_job_gateway.create_docker_image_batch_job( + created_by=user.user_id, + owner=user.team_id, + job_config=job_request.model_dump(by_alias=True), + env=env, + command=command, + repo=image_repo, + tag=image_tag, + mount_location=config_file_path, + resource_requests=resource_requests, + labels=labels, + override_job_max_runtime_s=max_runtime_sec, + num_workers=num_workers, + annotations=self.encode_metadata( + CustomJobMetadata( + input_data_path=job_request.input_data_path, + output_data_path=job_request.output_data_path, + expires_at=expires_at.isoformat(), + priority=job_request.priority, + labels=job_request.labels, + ) + ), + ) + return BatchCompletionsJob( + job_id=job_id, + input_data_path=job_request.input_data_path, + output_data_path=job_request.output_data_path, + model_config=job_request.model_cfg, + priority=job_request.priority, + status=BatchCompletionsJobStatus.Queued, + created_at=datetime.now().isoformat(), + expires_at=expires_at.isoformat(), + completed_at=None, + metadata={"labels": job_request.labels}, + ) + + async def get_batch_job(self, batch_job_id: str, user: User) -> Optional[BatchCompletionsJob]: + job = await self.docker_image_batch_job_gateway.get_docker_image_batch_job( + batch_job_id=batch_job_id + ) + + if job is None: + return None + + custom_metadata = self.decode_metadata(job.annotations or {}) + model_config = "[Cannot retrieve] -- please check the job logs" + + return BatchCompletionsJob( + job_id=batch_job_id, + input_data_path=custom_metadata.input_data_path, + output_data_path=custom_metadata.output_data_path, + model_config=model_config, + priority=custom_metadata.priority, + status=to_dto(job.status), + created_at=job.created_at, + expires_at=custom_metadata.expires_at, + completed_at=job.completed_at, + metadata={"labels": custom_metadata.labels}, + ) + + async def update_batch_job( + self, batch_job_id: str, request: UpdateBatchCompletionsV2Request, user: User + ) -> UpdateBatchCompletionsV2Response: + raise NotImplementedError("Not supported") + + async def cancel_batch_job(self, batch_job_id: str, user: User) -> bool: + return await self.docker_image_batch_job_gateway.update_docker_image_batch_job( + batch_job_id=batch_job_id, cancel=True + ) diff --git a/server/llm_engine_server/infra/services/live_llm_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py similarity index 79% rename from server/llm_engine_server/infra/services/live_llm_model_endpoint_service.py rename to model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py index 0eb4e3e4..644e0df6 100644 --- a/server/llm_engine_server/infra/services/live_llm_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_llm_model_endpoint_service.py @@ -1,15 +1,15 @@ from typing import List, Optional -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.domain.services import LLMModelEndpointService -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.services import LLMModelEndpointService +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) -from llm_engine_server.infra.services import LiveModelEndpointService +from model_engine_server.infra.services import LiveModelEndpointService -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) class LiveLLMModelEndpointService(LLMModelEndpointService): diff --git a/server/llm_engine_server/infra/services/live_model_endpoint_service.py b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py similarity index 85% rename from server/llm_engine_server/infra/services/live_model_endpoint_service.py rename to model-engine/model_engine_server/infra/services/live_model_endpoint_service.py index 5f671676..8cda63c5 100644 --- a/server/llm_engine_server/infra/services/live_model_endpoint_service.py +++ b/model-engine/model_engine_server/infra/services/live_model_endpoint_service.py @@ -1,14 +1,10 @@ from typing import Any, Dict, List, Optional from datadog import statsd -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.common.settings import generate_deployment_name -from llm_engine_server.core.domain_exceptions import ( - ObjectAlreadyExistsException, - ObjectNotFoundException, -) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.common.settings import generate_deployment_name +from model_engine_server.core.loggers import logger_name, make_logger +from model_engine_server.domain.entities import ( CallbackAuth, CpuSpecificationType, GpuType, @@ -20,24 +16,32 @@ ModelEndpointType, StorageSpecificationType, ) -from llm_engine_server.domain.exceptions import EndpointDeleteFailedException -from llm_engine_server.domain.gateways import ( +from model_engine_server.domain.exceptions import ( + EndpointDeleteFailedException, + ObjectAlreadyExistsException, + ObjectNotFoundException, +) +from model_engine_server.domain.gateways import ( AsyncModelEndpointInferenceGateway, ModelEndpointsSchemaGateway, StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, ) -from llm_engine_server.domain.services import ModelEndpointService -from llm_engine_server.infra.gateways import ModelEndpointInfraGateway -from llm_engine_server.infra.repositories import ModelEndpointCacheRepository -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( +from model_engine_server.domain.gateways.inference_autoscaling_metrics_gateway import ( + InferenceAutoscalingMetricsGateway, +) +from model_engine_server.domain.services import ModelEndpointService +from model_engine_server.domain.use_cases.model_endpoint_use_cases import MODEL_BUNDLE_CHANGED_KEY +from model_engine_server.infra.gateways import ModelEndpointInfraGateway +from model_engine_server.infra.repositories import ModelEndpointCacheRepository +from model_engine_server.infra.repositories.model_endpoint_record_repository import ( ModelEndpointRecordRepository, ) -logger = make_logger(filename_wo_ext(__file__)) +logger = make_logger(logger_name()) -STATSD_CACHE_HIT_NAME = "llm_engine_server.get_infra_state.cache_hit" -STATSD_CACHE_MISS_NAME = "llm_engine_server.get_infra_state.cache_miss" +STATSD_CACHE_HIT_NAME = "launch.get_infra_state.cache_hit" +STATSD_CACHE_MISS_NAME = "launch.get_infra_state.cache_miss" class LiveModelEndpointService(ModelEndpointService): @@ -50,6 +54,8 @@ def __init__( streaming_model_endpoint_inference_gateway: StreamingModelEndpointInferenceGateway, sync_model_endpoint_inference_gateway: SyncModelEndpointInferenceGateway, model_endpoints_schema_gateway: ModelEndpointsSchemaGateway, + inference_autoscaling_metrics_gateway: InferenceAutoscalingMetricsGateway, + can_scale_http_endpoint_from_zero_flag: bool, ): self.model_endpoint_record_repository = model_endpoint_record_repository self.model_endpoint_infra_gateway = model_endpoint_infra_gateway @@ -58,6 +64,8 @@ def __init__( self.streaming_model_endpoint_inference_gateway = streaming_model_endpoint_inference_gateway self.sync_model_endpoint_inference_gateway = sync_model_endpoint_inference_gateway self.model_endpoints_schema_gateway = model_endpoints_schema_gateway + self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway + self.can_scale_http_endpoint_from_zero_flag = can_scale_http_endpoint_from_zero_flag def get_async_model_endpoint_inference_gateway( self, @@ -74,6 +82,11 @@ def get_streaming_model_endpoint_inference_gateway( ) -> StreamingModelEndpointInferenceGateway: return self.streaming_model_endpoint_inference_gateway + def get_inference_autoscaling_metrics_gateway( + self, + ) -> InferenceAutoscalingMetricsGateway: + return self.inference_autoscaling_metrics_gateway + async def _get_model_endpoint_infra_state( self, record: ModelEndpointRecord, use_cache: bool ) -> Optional[ModelEndpointInfraState]: @@ -134,7 +147,8 @@ async def create_model_endpoint( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, min_workers: int, max_workers: int, @@ -144,6 +158,7 @@ async def create_model_endpoint( results_s3_bucket: str, prewarm: bool, high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, owner: str, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth], @@ -181,6 +196,7 @@ async def create_model_endpoint( memory=memory, gpu_type=gpu_type, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, aws_role=aws_role, results_s3_bucket=results_s3_bucket, @@ -267,6 +283,7 @@ async def update_model_endpoint( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, public_inference: Optional[bool] = None, @@ -297,6 +314,12 @@ async def update_model_endpoint( # f"Resource update on endpoint {name} in progress, try again later" # ) + if record.current_model_bundle.id != model_bundle_id: + if metadata is None: + metadata = record.metadata if record.metadata is not None else {} + # MODEL_BUNDLE_CHANGED_KEY will be checked during _create_deployment in K8SEndpointResourceDelegate + metadata[MODEL_BUNDLE_CHANGED_KEY] = True + record = await self.model_endpoint_record_repository.update_model_endpoint_record( model_endpoint_id=model_endpoint_id, model_bundle_id=model_bundle_id, @@ -324,9 +347,15 @@ async def update_model_endpoint( default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, ) + + # Clean up MODEL_BUNDLE_CHANGED_KEY as it is only for internal use + if metadata is not None and MODEL_BUNDLE_CHANGED_KEY in metadata: + del metadata[MODEL_BUNDLE_CHANGED_KEY] + await self.model_endpoint_record_repository.update_model_endpoint_record( model_endpoint_id=model_endpoint_id, creation_task_id=creation_task_id, + metadata=metadata, ) record = await self.model_endpoint_record_repository.get_model_endpoint_record( @@ -373,3 +402,6 @@ async def delete_model_endpoint(self, model_endpoint_id: str) -> None: ) logger.info(f"Endpoint delete released lock for {created_by}, {name}") + + def can_scale_http_endpoint_from_zero(self) -> bool: + return self.can_scale_http_endpoint_from_zero_flag diff --git a/server/llm_engine_server/infra/services/model_endpoint_cache_service.py b/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py similarity index 76% rename from server/llm_engine_server/infra/services/model_endpoint_cache_service.py rename to model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py index 824370f5..7e193027 100644 --- a/server/llm_engine_server/infra/services/model_endpoint_cache_service.py +++ b/model-engine/model_engine_server/infra/services/model_endpoint_cache_service.py @@ -1,13 +1,13 @@ from typing import Dict, Tuple -from llm_engine_server.domain.entities import ModelEndpointInfraState -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.domain.entities import ModelEndpointInfraState +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, ) -from llm_engine_server.infra.repositories.model_endpoint_cache_repository import ( +from model_engine_server.infra.repositories.model_endpoint_cache_repository import ( ModelEndpointCacheRepository, ) -from llm_engine_server.infra.services.image_cache_service import ImageCacheService +from model_engine_server.infra.services.image_cache_service import ImageCacheService class ModelEndpointCacheWriteService: @@ -26,9 +26,9 @@ def __init__( self.image_cache_service = image_cache_service async def execute(self, ttl_seconds: float): - endpoint_infra_states: Dict[ - str, Tuple[bool, ModelEndpointInfraState] - ] = await self.resource_gateway.get_all_resources() + endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpointInfraState]] = ( + await self.resource_gateway.get_all_resources() + ) for key, (is_key_an_endpoint_id, state) in endpoint_infra_states.items(): if is_key_an_endpoint_id: diff --git a/model-engine/model_engine_server/service_builder/__init__.py b/model-engine/model_engine_server/service_builder/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/model_engine_server/service_builder/celery.py b/model-engine/model_engine_server/service_builder/celery.py new file mode 100644 index 00000000..06384c9e --- /dev/null +++ b/model-engine/model_engine_server/service_builder/celery.py @@ -0,0 +1,25 @@ +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.celery import celery_app +from model_engine_server.core.config import infra_config + +service_builder_broker_type: str +if CIRCLECI: + service_builder_broker_type = str(BrokerType.REDIS.value) +elif infra_config().cloud_provider == "azure": + service_builder_broker_type = str(BrokerType.SERVICEBUS.value) +else: + service_builder_broker_type = str(BrokerType.SQS.value) + +service_builder_service = celery_app( + name="model_engine_server.service_builder", + modules=[ + "model_engine_server.service_builder.tasks_v1", + ], + s3_bucket=infra_config().s3_bucket, + broker_type=service_builder_broker_type, + backend_protocol="abs" if infra_config().cloud_provider == "azure" else "s3", +) + +if __name__ == "__main__": + service_builder_service.start() diff --git a/model-engine/model_engine_server/service_builder/tasks_v1.py b/model-engine/model_engine_server/service_builder/tasks_v1.py new file mode 100644 index 00000000..cd4ff63c --- /dev/null +++ b/model-engine/model_engine_server/service_builder/tasks_v1.py @@ -0,0 +1,136 @@ +import asyncio +import os +from typing import Any, Dict + +import aioredis +from celery.signals import worker_process_init +from model_engine_server.api.dependencies import get_monitoring_metrics_gateway +from model_engine_server.common.config import hmi_config +from model_engine_server.common.constants import READYZ_FPATH +from model_engine_server.common.dtos.endpoint_builder import ( + BuildEndpointRequest, + BuildEndpointResponse, +) +from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.core.config import infra_config +from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway +from model_engine_server.db.base import get_session_async_null_pool +from model_engine_server.domain.repositories import DockerRepository +from model_engine_server.infra.gateways import ( + ABSFilesystemGateway, + ASBInferenceAutoscalingMetricsGateway, + RedisInferenceAutoscalingMetricsGateway, + S3FilesystemGateway, +) +from model_engine_server.infra.gateways.resources.asb_queue_endpoint_resource_delegate import ( + ASBQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.fake_queue_endpoint_resource_delegate import ( + FakeQueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( + set_lazy_load_kubernetes_clients, +) +from model_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( + LiveEndpointResourceGateway, +) +from model_engine_server.infra.gateways.resources.queue_endpoint_resource_delegate import ( + QueueEndpointResourceDelegate, +) +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, +) +from model_engine_server.infra.repositories import ( + ACRDockerRepository, + DbModelEndpointRecordRepository, + ECRDockerRepository, + FakeDockerRepository, + RedisFeatureFlagRepository, + RedisModelEndpointCacheRepository, +) +from model_engine_server.infra.services import LiveEndpointBuilderService +from model_engine_server.service_builder.celery import service_builder_service + +# Need to disable lazy loading of k8s clients because each event loop should contain its own k8s +# client, which constructs the aiohttp.ClientSession in the event loop. +set_lazy_load_kubernetes_clients(False) + + +def get_live_endpoint_builder_service( + session: Any, + redis: aioredis.Redis, +): + queue_delegate: QueueEndpointResourceDelegate + if CIRCLECI: + queue_delegate = FakeQueueEndpointResourceDelegate() + elif infra_config().cloud_provider == "azure": + queue_delegate = ASBQueueEndpointResourceDelegate() + else: + queue_delegate = SQSQueueEndpointResourceDelegate( + sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) + ) + notification_gateway = FakeNotificationGateway() + monitoring_metrics_gateway = get_monitoring_metrics_gateway() + docker_repository: DockerRepository + if CIRCLECI: + docker_repository = FakeDockerRepository() + elif infra_config().docker_repo_prefix.endswith("azurecr.io"): + docker_repository = ACRDockerRepository() + else: + docker_repository = ECRDockerRepository() + inference_autoscaling_metrics_gateway = ( + ASBInferenceAutoscalingMetricsGateway() + if infra_config().cloud_provider == "azure" + else RedisInferenceAutoscalingMetricsGateway(redis_client=redis) + ) + service = LiveEndpointBuilderService( + docker_repository=docker_repository, + resource_gateway=LiveEndpointResourceGateway( + queue_delegate=queue_delegate, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, + ), + monitoring_metrics_gateway=monitoring_metrics_gateway, + model_endpoint_record_repository=DbModelEndpointRecordRepository( + monitoring_metrics_gateway=monitoring_metrics_gateway, session=session, read_only=False + ), + model_endpoint_cache_repository=RedisModelEndpointCacheRepository(redis_client=redis), + filesystem_gateway=( + ABSFilesystemGateway() + if infra_config().cloud_provider == "azure" + else S3FilesystemGateway() + ), + notification_gateway=notification_gateway, + feature_flag_repo=RedisFeatureFlagRepository(redis_client=redis), + ) + + return service + + +async def _build_endpoint( + build_endpoint_request: BuildEndpointRequest, +) -> BuildEndpointResponse: + session = get_session_async_null_pool() + pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) + redis = aioredis.Redis(connection_pool=pool) + service: LiveEndpointBuilderService = get_live_endpoint_builder_service(session, redis) + + response = await service.build_endpoint(build_endpoint_request) + await redis.close() + await pool.disconnect() + return response + + +@worker_process_init.connect +def init_worker(*args, **kwargs): + # k8s health check + with open(READYZ_FPATH, "w") as f: + f.write("READY") + + +@service_builder_service.task +def build_endpoint(build_endpoint_request_json: Dict[str, Any]) -> Dict[str, str]: + build_endpoint_request: BuildEndpointRequest = BuildEndpointRequest.parse_obj( + build_endpoint_request_json + ) + result = asyncio.run(_build_endpoint(build_endpoint_request)) + return result.dict() diff --git a/model-engine/mypy.ini b/model-engine/mypy.ini new file mode 100644 index 00000000..fc499b32 --- /dev/null +++ b/model-engine/mypy.ini @@ -0,0 +1,30 @@ +[mypy] +ignore_missing_imports = True +follow_imports = silent +show_column_numbers = True +namespace_packages = True +explicit_package_bases = True +strict_optional = True +plugins = pydantic.mypy +exclude = clients|.*/triton_model_repo/.* + +[mypy-model_engine_server.core.*] +ignore_errors = True + +[mypy-model_engine_server.db.*] +ignore_errors = True + +[mypy-model_engine_server.db.base] +ignore_errors = False + +[mypy-model_engine_server.infra.repositories.*] +ignore_errors = True + +[mypy-clients.*] +ignore_errors = True + +[mypy-tests.*] +ignore_errors = True + +[mypy-model_engine_server.common.types.gen.openai] +ignore_errors = True \ No newline at end of file diff --git a/server/requirements-test.txt b/model-engine/requirements-test.txt similarity index 79% rename from server/requirements-test.txt rename to model-engine/requirements-test.txt index 719527a1..0115722b 100644 --- a/server/requirements-test.txt +++ b/model-engine/requirements-test.txt @@ -1,14 +1,21 @@ +aioresponses>=0.7.6 +coverage==5.5 +diff-cover==7.7.0 +frozendict==2.3.4 +func-timeout==4.3.5 multiprocess==0.70.14 -pytest==7.2.0 -pytest-cov==2.10.0 moto==3.1.12 -coverage==5.5 mypy==1.3.0 +pylint<3.0.0 +pytest==7.2.0 +pytest-asyncio==0.20.1 +pytest-cov==2.10.0 pytest-mypy==0.9.1 pytest-mypy-plugins==1.10.1 -pytest-asyncio==0.20.1 pytest-pylint==0.18.0 +requests-mock==1.9.3 types-cachetools==5.3.0.5 +types-croniter==1.4.0.0 types-PyYAML==6.0.7 types-redis==4.3.21.3 types-requests==2.27.26 @@ -19,4 +26,3 @@ types-toml==0.10.8 types-ujson==5.5.0 types-urllib3==1.26.14 types-waitress==2.1.4 -frozendict==2.3.4 diff --git a/model-engine/requirements.in b/model-engine/requirements.in new file mode 100644 index 00000000..d503f7b8 --- /dev/null +++ b/model-engine/requirements.in @@ -0,0 +1,62 @@ +GitPython~=3.1 +Jinja2==3.0.3 # version 3.1.0 had a bug +aiohttp~=3.9 +aioredis~=2.0 +alembic==1.8.1 +asyncpg==0.27.0 +azure-containerregistry~=1.2.0 +azure-identity~=1.15.0 +azure-keyvault-secrets~=4.7.0 +azure-servicebus~=7.11.4 +azure-storage-blob~=12.19.0 +boto3-stubs[essential]~=1.26.67 +boto3~=1.21 +botocore~=1.24 +build~=1.0.3 +celery[redis,sqs,tblib]~=5.4.0 +click~=8.1 +cloudpickle==2.1.0 +croniter==1.4.1 +cryptography>=42.0.4 # not used directly, but needs to be pinned for Microsoft security scan +dataclasses-json>=0.5.7 +datadog-api-client==2.11.0 +datadog~=0.47.0 +ddtrace==1.8.3 +deprecation~=2.1 +docker~=5.0 +fastapi~=0.110.0 +gitdb2~=2.0 +gunicorn~=20.0 +httptools==0.5.0 +json-log-formatter~=0.3 +kubeconfig~=1.1 +kubernetes-asyncio==25.11.0 +kubernetes~=25.3.0 +orjson==3.9.15 +protobuf~=3.20 +psycopg2-binary==2.9.3 +py-xid==0.3.0 +pycurl~=7.44 # For celery[sqs] +pydantic==2.8.2 +python-multipart~=0.0.7 +quart==0.18.3 +requests-auth-aws-sigv4~=0.7 +requests~=2.25 +rich~=12.6 +sentencepiece==0.1.99 +sh~=1.13 +smart-open~=5.2 +sqlalchemy[asyncio]~=2.0.4 +sse-starlette==1.6.1 +sseclient-py==1.7.2 +starlette[full]>=0.36.2 # not used directly, but needs to be pinned for Microsoft security scan +stringcase==1.2.0 +tenacity>=6.0.0,<=6.2.0 +testing-postgresql==1.3.0 +tokenizers~=0.15.2 +tqdm~=4.64 +transformers==4.38.0 +twine==3.7.1 +uvicorn==0.30.6 +uvloop==0.17.0 +yarl~=1.4 \ No newline at end of file diff --git a/model-engine/requirements.txt b/model-engine/requirements.txt new file mode 100644 index 00000000..6e784ecc --- /dev/null +++ b/model-engine/requirements.txt @@ -0,0 +1,578 @@ +# +# This file is autogenerated by pip-compile with Python 3.10 +# by the following command: +# +# pip-compile --allow-unsafe --index-url=https://pypi.org/simple --no-emit-index-url --no-emit-trusted-host model-engine/requirements.in +# +aiofiles==23.1.0 + # via quart +aiohttp==3.9.2 + # via + # -r model-engine/requirements.in + # kubernetes-asyncio +aioredis==2.0.1 + # via -r model-engine/requirements.in +aiosignal==1.3.1 + # via aiohttp +alembic==1.8.1 + # via -r model-engine/requirements.in +amqp==5.1.1 + # via kombu +annotated-types==0.7.0 + # via pydantic +anyio==3.7.1 + # via + # azure-core + # httpx + # starlette +asn1crypto==1.5.1 + # via scramp +async-timeout==4.0.2 + # via + # aiohttp + # aioredis + # redis +asyncpg==0.27.0 + # via -r model-engine/requirements.in +attrs==23.1.0 + # via + # aiohttp + # cattrs + # ddtrace + # jsonschema + # referencing +azure-common==1.1.28 + # via azure-keyvault-secrets +azure-containerregistry==1.2.0 + # via -r model-engine/requirements.in +azure-core==1.29.6 + # via + # azure-containerregistry + # azure-identity + # azure-keyvault-secrets + # azure-servicebus + # azure-storage-blob +azure-identity==1.15.0 + # via -r model-engine/requirements.in +azure-keyvault-secrets==4.7.0 + # via -r model-engine/requirements.in +azure-servicebus==7.11.4 + # via -r model-engine/requirements.in +azure-storage-blob==12.19.0 + # via -r model-engine/requirements.in +billiard==4.2.0 + # via celery +bleach==6.0.0 + # via readme-renderer +blinker==1.6.2 + # via quart +boto3==1.28.1 + # via + # -r model-engine/requirements.in + # celery + # kombu +boto3-stubs[essential]==1.26.67 + # via -r model-engine/requirements.in +botocore==1.31.1 + # via + # -r model-engine/requirements.in + # boto3 + # s3transfer +botocore-stubs==1.29.165 + # via boto3-stubs +build==1.0.3 + # via -r model-engine/requirements.in +bytecode==0.14.2 + # via ddtrace +cachetools==5.3.1 + # via google-auth +cattrs==23.1.2 + # via ddtrace +celery[redis,sqs,tblib]==5.4.0 + # via -r model-engine/requirements.in +certifi==2023.7.22 + # via + # datadog-api-client + # httpcore + # httpx + # kubernetes + # kubernetes-asyncio + # requests +cffi==1.16.0 + # via cryptography +charset-normalizer==3.2.0 + # via requests +click==8.1.4 + # via + # -r model-engine/requirements.in + # celery + # click-didyoumean + # click-plugins + # click-repl + # quart + # uvicorn +click-didyoumean==0.3.0 + # via celery +click-plugins==1.1.1 + # via celery +click-repl==0.3.0 + # via celery +cloudpickle==2.1.0 + # via -r model-engine/requirements.in +colorama==0.4.6 + # via twine +commonmark==0.9.1 + # via rich +croniter==1.4.1 + # via -r model-engine/requirements.in +cryptography==42.0.5 + # via + # -r model-engine/requirements.in + # azure-identity + # azure-storage-blob + # msal + # pyjwt + # secretstorage +dataclasses-json==0.5.9 + # via -r model-engine/requirements.in +datadog==0.47.0 + # via -r model-engine/requirements.in +datadog-api-client==2.11.0 + # via -r model-engine/requirements.in +ddsketch==2.0.4 + # via ddtrace +ddtrace==1.8.3 + # via -r model-engine/requirements.in +deprecation==2.1.0 + # via -r model-engine/requirements.in +docker==5.0.3 + # via -r model-engine/requirements.in +docutils==0.20.1 + # via readme-renderer +envier==0.4.0 + # via ddtrace +exceptiongroup==1.2.0 + # via + # anyio + # cattrs +fastapi==0.110.0 + # via -r model-engine/requirements.in +filelock==3.13.1 + # via + # huggingface-hub + # transformers +frozenlist==1.3.3 + # via + # aiohttp + # aiosignal +fsspec==2023.10.0 + # via huggingface-hub +gitdb==4.0.10 + # via gitpython +gitdb2==2.0.6 + # via -r model-engine/requirements.in +gitpython==3.1.41 + # via -r model-engine/requirements.in +google-auth==2.21.0 + # via kubernetes +greenlet==2.0.2 + # via sqlalchemy +gunicorn==20.1.0 + # via -r model-engine/requirements.in +h11==0.14.0 + # via + # httpcore + # hypercorn + # uvicorn + # wsproto +h2==4.1.0 + # via hypercorn +hpack==4.0.0 + # via h2 +httpcore==1.0.4 + # via httpx +httptools==0.5.0 + # via -r model-engine/requirements.in +httpx==0.27.0 + # via starlette +huggingface-hub==0.20.3 + # via + # tokenizers + # transformers +hypercorn==0.14.4 + # via quart +hyperframe==6.0.1 + # via h2 +idna==3.7 + # via + # anyio + # httpx + # requests + # yarl +importlib-metadata==6.8.0 + # via + # keyring + # twine +isodate==0.6.1 + # via + # azure-containerregistry + # azure-keyvault-secrets + # azure-servicebus + # azure-storage-blob +itsdangerous==2.1.2 + # via + # quart + # starlette +jaraco-classes==3.3.0 + # via keyring +jeepney==0.8.0 + # via + # keyring + # secretstorage +jinja2==3.0.3 + # via + # -r model-engine/requirements.in + # quart + # starlette +jmespath==1.0.1 + # via + # boto3 + # botocore +json-log-formatter==0.5.2 + # via -r model-engine/requirements.in +jsonschema==4.19.0 + # via ddtrace +jsonschema-specifications==2023.7.1 + # via jsonschema +keyring==24.2.0 + # via twine +kombu[sqs]==5.3.5 + # via celery +kubeconfig==1.1.1 + # via -r model-engine/requirements.in +kubernetes==25.3.0 + # via -r model-engine/requirements.in +kubernetes-asyncio==25.11.0 + # via -r model-engine/requirements.in +mako==1.2.4 + # via alembic +markupsafe==2.1.3 + # via + # jinja2 + # mako + # quart + # werkzeug +marshmallow==3.19.0 + # via + # dataclasses-json + # marshmallow-enum +marshmallow-enum==1.5.1 + # via dataclasses-json +more-itertools==9.1.0 + # via jaraco-classes +msal==1.26.0 + # via + # azure-identity + # msal-extensions +msal-extensions==1.1.0 + # via azure-identity +multidict==6.0.4 + # via + # aiohttp + # yarl +mypy-boto3-cloudformation==1.26.156 + # via boto3-stubs +mypy-boto3-dynamodb==1.26.164 + # via boto3-stubs +mypy-boto3-ec2==1.26.157 + # via boto3-stubs +mypy-boto3-lambda==1.26.163 + # via boto3-stubs +mypy-boto3-rds==1.26.163 + # via boto3-stubs +mypy-boto3-s3==1.26.163 + # via boto3-stubs +mypy-boto3-sqs==1.26.148 + # via boto3-stubs +mypy-extensions==1.0.0 + # via typing-inspect +numpy==1.24.4 + # via transformers +oauthlib==3.2.2 + # via requests-oauthlib +orjson==3.9.15 + # via -r model-engine/requirements.in +packaging==23.1 + # via + # build + # ddtrace + # deprecation + # huggingface-hub + # marshmallow + # msal-extensions + # transformers +pg8000==1.29.8 + # via testing-postgresql +pkginfo==1.9.6 + # via twine +portalocker==2.8.2 + # via msal-extensions +priority==2.0.0 + # via hypercorn +prompt-toolkit==3.0.39 + # via click-repl +protobuf==3.20.3 + # via + # -r model-engine/requirements.in + # ddsketch + # ddtrace +psycopg2-binary==2.9.3 + # via -r model-engine/requirements.in +py-xid==0.3.0 + # via -r model-engine/requirements.in +pyasn1==0.5.0 + # via + # pyasn1-modules + # rsa +pyasn1-modules==0.3.0 + # via google-auth +pycparser==2.21 + # via cffi +pycurl==7.45.2 + # via + # -r model-engine/requirements.in + # celery + # kombu +pydantic==2.8.2 + # via + # -r model-engine/requirements.in + # fastapi +pydantic-core==2.20.1 + # via pydantic +pygments==2.15.1 + # via + # readme-renderer + # rich +pyjwt[crypto]==2.8.0 + # via + # msal + # pyjwt +pyproject-hooks==1.0.0 + # via build +python-dateutil==2.8.2 + # via + # botocore + # celery + # croniter + # datadog-api-client + # kubernetes + # kubernetes-asyncio + # pg8000 +python-multipart==0.0.7 + # via + # -r model-engine/requirements.in + # starlette +pyyaml==6.0.1 + # via + # huggingface-hub + # kubeconfig + # kubernetes + # kubernetes-asyncio + # starlette + # transformers +quart==0.18.3 + # via -r model-engine/requirements.in +readme-renderer==40.0 + # via twine +redis==4.6.0 + # via celery +referencing==0.30.2 + # via + # jsonschema + # jsonschema-specifications +regex==2023.10.3 + # via transformers +requests==2.31.0 + # via + # -r model-engine/requirements.in + # azure-core + # datadog + # docker + # huggingface-hub + # kubernetes + # msal + # requests-auth-aws-sigv4 + # requests-oauthlib + # requests-toolbelt + # transformers + # twine +requests-auth-aws-sigv4==0.7 + # via -r model-engine/requirements.in +requests-oauthlib==1.3.1 + # via kubernetes +requests-toolbelt==1.0.0 + # via twine +rfc3986==2.0.0 + # via twine +rich==12.6.0 + # via -r model-engine/requirements.in +rpds-py==0.10.0 + # via + # jsonschema + # referencing +rsa==4.9 + # via google-auth +s3transfer==0.6.1 + # via boto3 +safetensors==0.4.2 + # via transformers +scramp==1.4.4 + # via pg8000 +secretstorage==3.3.3 + # via keyring +sentencepiece==0.1.99 + # via -r model-engine/requirements.in +sh==1.14.3 + # via -r model-engine/requirements.in +six==1.16.0 + # via + # azure-core + # bleach + # ddsketch + # ddtrace + # google-auth + # isodate + # kubernetes + # kubernetes-asyncio + # python-dateutil + # tenacity +smart-open==5.2.1 + # via -r model-engine/requirements.in +smmap==5.0.0 + # via + # gitdb + # smmap2 +smmap2==3.0.1 + # via gitdb2 +sniffio==1.3.0 + # via + # anyio + # httpx +sqlalchemy[asyncio]==2.0.4 + # via + # -r model-engine/requirements.in + # alembic +sse-starlette==1.6.1 + # via -r model-engine/requirements.in +sseclient-py==1.7.2 + # via -r model-engine/requirements.in +starlette[full]==0.36.3 + # via + # -r model-engine/requirements.in + # fastapi + # sse-starlette +stringcase==1.2.0 + # via -r model-engine/requirements.in +tblib==2.0.0 + # via celery +tenacity==6.2.0 + # via + # -r model-engine/requirements.in + # ddtrace +testing-common-database==2.0.3 + # via testing-postgresql +testing-postgresql==1.3.0 + # via -r model-engine/requirements.in +tokenizers==0.15.2 + # via + # -r model-engine/requirements.in + # transformers +tomli==2.0.1 + # via + # build + # hypercorn + # pyproject-hooks +tqdm==4.65.0 + # via + # -r model-engine/requirements.in + # huggingface-hub + # transformers + # twine +transformers==4.38.0 + # via -r model-engine/requirements.in +twine==3.7.1 + # via -r model-engine/requirements.in +types-awscrt==0.16.23 + # via + # botocore-stubs + # types-s3transfer +types-s3transfer==0.6.1 + # via boto3-stubs +typing-extensions==4.10.0 + # via + # aioredis + # azure-core + # azure-keyvault-secrets + # azure-servicebus + # azure-storage-blob + # boto3-stubs + # cattrs + # datadog-api-client + # ddtrace + # fastapi + # huggingface-hub + # pydantic + # pydantic-core + # sqlalchemy + # typing-inspect + # uvicorn +typing-inspect==0.9.0 + # via dataclasses-json +tzdata==2023.3 + # via celery +urllib3==1.26.16 + # via + # botocore + # celery + # datadog-api-client + # google-auth + # kombu + # kubernetes + # kubernetes-asyncio + # requests +uvicorn==0.30.6 + # via -r model-engine/requirements.in +uvloop==0.17.0 + # via -r model-engine/requirements.in +vine==5.1.0 + # via + # amqp + # celery + # kombu +wcwidth==0.2.6 + # via prompt-toolkit +webencodings==0.5.1 + # via bleach +websocket-client==1.6.1 + # via + # docker + # kubernetes +werkzeug==2.3.6 + # via quart +wsproto==1.2.0 + # via hypercorn +xmltodict==0.13.0 + # via ddtrace +yarl==1.9.2 + # via + # -r model-engine/requirements.in + # aiohttp +zipp==3.16.0 + # via importlib-metadata + +# The following packages are considered to be unsafe in a requirements file: +setuptools==69.0.3 + # via + # gunicorn + # kubernetes + # kubernetes-asyncio diff --git a/model-engine/requirements_override.txt b/model-engine/requirements_override.txt new file mode 100644 index 00000000..0520f838 --- /dev/null +++ b/model-engine/requirements_override.txt @@ -0,0 +1,2 @@ +# Consists of packages that need to be explicitly different from those in requirements.txt +aioboto3==10.4.0 diff --git a/model-engine/service_configs/service_config_circleci.yaml b/model-engine/service_configs/service_config_circleci.yaml new file mode 100644 index 00000000..31644ad2 --- /dev/null +++ b/model-engine/service_configs/service_config_circleci.yaml @@ -0,0 +1,72 @@ +# Config to know where model-engine is running +gateway_namespace: default + +# Config for Model Engine running in CircleCI +model_primitive_host: "none" + +# Endpoint config +# K8s namespace the endpoints will be created in +endpoint_namespace: model-engine + +# Asynchronous endpoints +# TODO: Try out localstack once e2e tests have been updated to use sqs as a broker_type +sqs_profile: nonexistent_sqs_profile +sqs_queue_policy_template: > + { + "Version": "2012-10-17", + "Id": "__default_policy_ID", + "Statement": [ + { + "Sid": "__owner_statement", + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::000000000000:root" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" + }, + { + "Effect": "Allow", + "Principal": { + "AWS": "arn:aws:iam::000000000000:role/default" + }, + "Action": "sqs:*", + "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" + } + ] + } + +sqs_queue_tag_template: > + { + "infra.scale.com/product": "MLInfraLaunchSQS", + "infra.scale.com/team": "${team}", + "infra.scale.com/contact": "yi.xu@scale.com", + "infra.scale.com/customer": "AllCustomers", + "infra.scale.com/financialOwner": "yi.xu@scale.com", + "Launch-Endpoint-Id": "${endpoint_id}", + "Launch-Endpoint-Name": "${endpoint_name}", + "Launch-Endpoint-Created-By": "${endpoint_created_by}" + } + +# Billing +billing_queue_arn: none +# There's a separate piece of infra that caches k8s state onto redis, so we need a url to it +cache_redis_aws_url: redis://127.0.0.1:6379/15 + +cloud_file_llm_fine_tune_repository: "s3://model-engine-integration-tests/fine_tune_repository/circleci" + +dd_trace_enabled: false +istio_enabled: true +sensitive_log_mode: false +tgi_repository: "text-generation-inference" +vllm_repository: "vllm" +lightllm_repository: "lightllm" +tensorrt_llm_repository: "tensorrt-llm" +batch_inference_vllm_repository: "llm-engine/batch-infer-vllm" +user_inference_base_repository: "launch/inference" +user_inference_pytorch_repository: "hosted-model-inference/async-pytorch" +user_inference_tensorflow_repository: "hosted-model-inference/async-tensorflow-cpu" +docker_image_layer_cache_repository: "kaniko-cache" + +# S3 access +hf_user_fine_tuned_weights_prefix: "s3://test-bucket/model-weights" diff --git a/model-engine/setup.cfg b/model-engine/setup.cfg new file mode 100644 index 00000000..6b2273a8 --- /dev/null +++ b/model-engine/setup.cfg @@ -0,0 +1,54 @@ +[aliases] +test=pytest + +[coverage:run] +omit = + model_engine_server/entrypoints/* + model_engine_server/api/app.py + model_engine_server/api/dependencies.py + model_engine_server/common/config.py + model_engine_server/common/io.py + model_engine_server/core/celery/app.py + model_engine_server/core/docker/ecr.py + model_engine_server/db/base.py + model_engine_server/infra/gateways/abs_file_storage_gateway.py + model_engine_server/infra/gateways/abs_filesystem_gateway.py + model_engine_server/infra/gateways/abs_llm_artifact_gateway.py + model_engine_server/infra/gateways/asb_inference_autoscaling_metrics_gateway.py + model_engine_server/infra/gateways/redis_inference_autoscaling_metrics_gateway.py + model_engine_server/infra/gateways/resources/asb_queue_endpoint_resource_delegate.py + model_engine_server/infra/gateways/resources/live_endpoint_resource_gateway.py + model_engine_server/infra/gateways/resources/k8s_resource_types.py + model_engine_server/infra/repositories/abs_file_llm_fine_tune_events_repository.py + model_engine_server/infra/repositories/abs_file_llm_fine_tune_repository.py + model_engine_server/infra/repositories/acr_docker_repository.py + +# TODO: Fix pylint errors +# [pylint] +# ignore-paths = test/* +# disable = +# I0011, +# R0801, R0902, R0903, R0913, +# W0703, W1202, W1203, W1514, +# C0114, C0411, +# E0611, +# W0511, +# W0622, +# output-format = colorized +# max-line-length = 120 + + +[tool:pytest] +addopts = + --verbose + --durations=0 + --cache-clear + --cov=model_engine_server + --cov-report=term-missing + --mypy + --mypy-ini-file=mypy.ini + --ignore=clients +# Need to specify this since pytest override mypy.ini See https://github.com/realpython/pytest-mypy/issues/123 + --ignore-glob=*triton_model_repo* +# --pylint +# --pylint-rcfile=setup.cfg diff --git a/model-engine/setup.py b/model-engine/setup.py new file mode 100644 index 00000000..bc0a0548 --- /dev/null +++ b/model-engine/setup.py @@ -0,0 +1,19 @@ +# To get circleci to work +from setuptools import find_packages, setup +setup( + name="model_engine_server", + version="1.0.0", + packages=[p for p in find_packages() if "tests" not in p], + install_requires=[], + entry_points={ + "console_scripts": [ + "start-service-builder=model_engine_server.start_service_builder:entrypoint", + "start-server=model_engine_server.start_server:entrypoint", + "start-fastapi-server=model_engine_server.entrypoints.start_fastapi_server:entrypoint", + "start-batch-job-orchestration=model_engine_server.entrypoints.start_batch_job_orchestration:entrypoint", + "hosted-inference-server=model_engine_server.entrypoints.hosted_inference_server:entrypoint", + "autogen=model_engine_server.scripts.autogenerate_client_and_docs:entrypoint", + "launch-admin=model_engine_server.cli.bin:entrypoint", + ], + } +) diff --git a/model-engine/tests/README.md b/model-engine/tests/README.md new file mode 100644 index 00000000..ed230099 --- /dev/null +++ b/model-engine/tests/README.md @@ -0,0 +1,7 @@ +## To Run Tests: + +```shell +pushd ../ +PYTHONPATH=hosted_model_inference WORKSPACE=. python3 -m pytest hosted_model_inference/tests --cov=hosted_model_inference +popd +``` diff --git a/model-engine/tests/__init__.py b/model-engine/tests/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/tests/integration/__init__.py b/model-engine/tests/integration/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/tests/integration/inference/conftest.py b/model-engine/tests/integration/inference/conftest.py similarity index 83% rename from server/tests/integration/inference/conftest.py rename to model-engine/tests/integration/inference/conftest.py index ec6eee64..07e900b0 100644 --- a/server/tests/integration/inference/conftest.py +++ b/model-engine/tests/integration/inference/conftest.py @@ -13,10 +13,10 @@ from fastapi import Depends, FastAPI, Request from fastapi.responses import JSONResponse from fastapi.security import HTTPBasic, HTTPBasicCredentials -from llm_engine_server.common.constants import READYZ_FPATH -from llm_engine_server.common.serialization_utils import python_json_to_b64 -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth, ModelEndpointConfig +from model_engine_server.common.constants import READYZ_FPATH +from model_engine_server.common.serialization_utils import python_json_to_b64 +from model_engine_server.core.config import infra_config +from model_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth, ModelEndpointConfig from tenacity import Retrying, retry_if_exception_type, stop_after_attempt, wait_fixed MODULE_PATH = Path(__file__).resolve() @@ -47,7 +47,7 @@ def test_user_id() -> str: @pytest.fixture(scope="session") def test_default_callback_auth() -> CallbackAuth: return CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="test_user", password="test_password") + root=CallbackBasicAuth(kind="basic", username="test_user", password="test_password") ) @@ -67,6 +67,8 @@ def endpoint_config_location(callback_port: int, test_user_id: str) -> Iterator[ post_inference_hooks=["callback"], default_callback_url=f"http://localhost:{callback_port}/v0/callback", user_id=test_user_id, + billing_queue=None, + billing_tags=None, ).serialize() with NamedTemporaryFile(mode="w+") as f: f.write(endpoint_config_serialized) @@ -75,30 +77,30 @@ def endpoint_config_location(callback_port: int, test_user_id: str) -> Iterator[ @pytest.fixture(scope="session") -def llm_engine_celery_app( +def launch_celery_app( queue: str, user_config_location: str, endpoint_config_location: str ) -> Iterator[subprocess.Popen]: env = dict( - AWS_PROFILE="default", + AWS_PROFILE="default" if os.getenv("CIRCLECI") else infra_config().profile_ml_worker, BROKER_TYPE="redis", USE_REDIS_LOCALHOST=1, - CELERY_S3_BUCKET=ml_infra_config().s3_bucket, - RESULTS_S3_BUCKET=ml_infra_config().s3_bucket, + CELERY_S3_BUCKET=infra_config().s3_bucket, + RESULTS_S3_BUCKET=infra_config().s3_bucket, CHILD_FN_INFO="{}", BASE_PATH=str(BASE_PATH), PREWARM=True, - BUNDLE_URL=f"s3://{ml_infra_config().s3_bucket}/model_bundles/61a67d767bce560024c7eb96/f0142411-51e1-4357-a405-ee5fef87d977", + BUNDLE_URL=f"s3://{infra_config().s3_bucket}/model_bundles/61a67d767bce560024c7eb96/f0142411-51e1-4357-a405-ee5fef87d977", USER_CONFIG_LOCATION=user_config_location, ENDPOINT_CONFIG_LOCATION=endpoint_config_location, ) env_str = " ".join(f"{k}={v}" for k, v in env.items()) command = ( - f"{env_str} exec celery --app=llm_engine_server.inference.async_inference worker " + f"{env_str} exec celery --app=model_engine_server.inference.async_inference worker " f"--loglevel=INFO --concurrency=1 --queues={queue}" ) # Wait up to 10 seconds for process to start and be ready. - with subprocess.Popen( + with subprocess.Popen( # nosemgrep command, shell=True, stdout=subprocess.PIPE, stderr=subprocess.PIPE ) as process: for attempt in Retrying( diff --git a/server/tests/integration/inference/test_async_inference.py b/model-engine/tests/integration/inference/test_async_inference.py similarity index 82% rename from server/tests/integration/inference/test_async_inference.py rename to model-engine/tests/integration/inference/test_async_inference.py index 91010221..d1d7f7c5 100644 --- a/server/tests/integration/inference/test_async_inference.py +++ b/model-engine/tests/integration/inference/test_async_inference.py @@ -4,20 +4,23 @@ import subprocess from functools import lru_cache from typing import Any, List, Optional, Tuple +from unittest.mock import MagicMock +import botocore import pytest import redis import requests from fastapi import FastAPI -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.model_endpoints import BrokerType +from model_engine_server.common.dtos.tasks import ( CallbackAuth, EndpointPredictV1Request, ResponseSchema, TaskStatus, ) -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.infra.gateways import ( +from model_engine_server.common.env_vars import CIRCLECI +from model_engine_server.domain.exceptions import InvalidRequestException +from model_engine_server.infra.gateways import ( CeleryTaskQueueGateway, LiveAsyncModelEndpointInferenceGateway, ) @@ -39,13 +42,13 @@ def redis_available() -> bool: @pytest.mark.parametrize( "task_args,cloudpickle,expected_status,expected_result", [ - ({"y": 1}, False, TaskStatus.SUCCESS, ResponseSchema(__root__={"result": "1"})), + ({"y": 1}, False, TaskStatus.SUCCESS, ResponseSchema(root={"result": "1"})), ({"x": False, "y": 1}, False, TaskStatus.FAILURE, None), ], ) def test_submit_and_get_tasks( queue: str, - llm_engine_celery_app: subprocess.Popen, + launch_celery_app: subprocess.Popen, callback_app: FastAPI, task_args: List[Any], cloudpickle: bool, @@ -94,7 +97,7 @@ def test_async_callbacks( queue: str, callback_port: int, test_user_id: str, - llm_engine_celery_app: subprocess.Popen, + launch_celery_app: subprocess.Popen, callback_app: FastAPI, callback_version: Optional[str], expected_callback_payload: Any, @@ -157,3 +160,24 @@ def test_async_callbacks( assert actual_payload == expected_callback_payload assert callback_stats["last_auth"][callback_version] == expected_credentials + + +def test_async_callbacks_botocore_exception( + queue: str, +): + gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) + + mock_dest = MagicMock() + mock_dest.send_task = MagicMock( + side_effect=botocore.exceptions.ClientError(error_response={}, operation_name="") + ) + mock_get = MagicMock() + mock_get.return_value = mock_dest + gateway._get_celery_dest = mock_get + + with pytest.raises(InvalidRequestException): + gateway.send_task( + task_name="test_task", + queue_name=queue, + args=[1, 2], + ) diff --git a/model-engine/tests/unit/__init__.py b/model-engine/tests/unit/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/tests/unit/api/__init__.py b/model-engine/tests/unit/api/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/tests/unit/api/conftest.py b/model-engine/tests/unit/api/conftest.py similarity index 76% rename from server/tests/unit/api/conftest.py rename to model-engine/tests/unit/api/conftest.py index f223c722..29454dca 100644 --- a/server/tests/unit/api/conftest.py +++ b/model-engine/tests/unit/api/conftest.py @@ -1,20 +1,26 @@ +import asyncio import datetime -from typing import Any, Dict, Iterator, Tuple +from typing import Any, Dict, Iterator, Optional, Tuple import pytest +import pytest_asyncio from fastapi import Depends, HTTPException from fastapi.security import HTTPBasicCredentials from fastapi.testclient import TestClient -from llm_engine_server.api.app import app -from llm_engine_server.api.dependencies import ( - AUTH, +from httpx import AsyncClient +from model_engine_server.api.app import app +from model_engine_server.api.dependencies import ( + basic_auth, get_external_interfaces, get_external_interfaces_read_only, + oauth2_scheme, verify_authentication, ) -from llm_engine_server.core.auth.authentication_repository import AuthenticationRepository, User -from llm_engine_server.core.auth.fake_authentication_repository import FakeAuthenticationRepository -from llm_engine_server.domain.entities import ( +from model_engine_server.core.auth.authentication_repository import AuthenticationRepository, User +from model_engine_server.core.auth.fake_authentication_repository import ( + FakeAuthenticationRepository, +) +from model_engine_server.domain.entities import ( BatchJob, BatchJobProgress, BatchJobRecord, @@ -39,10 +45,11 @@ PytorchFramework, StreamingEnhancedRunnableImageFlavor, TensorflowFramework, + Trigger, ZipArtifactFlavor, ) -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) @@ -59,18 +66,22 @@ def get_test_auth_repository() -> Iterator[AuthenticationRepository]: def fake_verify_authentication( - credentials: HTTPBasicCredentials = Depends(AUTH), + credentials: Optional[HTTPBasicCredentials] = Depends(basic_auth), + tokens: Optional[str] = Depends(oauth2_scheme), auth_repo: AuthenticationRepository = Depends(get_test_auth_repository), ) -> User: """ Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise, raises a 401. """ - auth_user_id = credentials.username if credentials is not None else None - if not auth_user_id: - raise HTTPException(status_code=401, detail="No user id was passed in") - - auth = auth_repo.get_auth_from_user_id(user_id=auth_user_id) + if credentials is not None: + auth_username = credentials.username + elif tokens is not None: + auth_username = tokens + else: + raise HTTPException(status_code=401, detail="No authentication was passed in") + + auth = auth_repo.get_auth_from_username(username=auth_username) if not auth: raise HTTPException(status_code=401, detail="Could not authenticate user") @@ -87,6 +98,14 @@ def fake_auth(): app.dependency_overrides[verify_authentication] = {} +@pytest_asyncio.fixture(scope="session", autouse=True) +def event_loop(request): + """Create an instance of the default event loop for each test case.""" + loop = asyncio.get_event_loop_policy().new_event_loop() + yield loop + loop.close() + + @pytest.fixture def get_test_client_wrapper(get_repositories_generator_wrapper): def get_test_client( @@ -99,6 +118,11 @@ def get_test_client( fake_docker_image_batch_job_bundle_repository_contents=None, fake_docker_image_batch_job_gateway_contents=None, fake_llm_fine_tuning_service_contents=None, + fake_file_storage_gateway_contents=None, + fake_file_system_gateway_contents=None, + fake_trigger_repository_contents=None, + fake_cron_job_gateway_contents=None, + fake_sync_inference_content=None, ) -> TestClient: if fake_docker_image_batch_job_gateway_contents is None: fake_docker_image_batch_job_gateway_contents = {} @@ -116,6 +140,16 @@ def get_test_client( fake_model_bundle_repository_contents = {} if fake_llm_fine_tuning_service_contents is None: fake_llm_fine_tuning_service_contents = {} + if fake_file_storage_gateway_contents is None: + fake_file_storage_gateway_contents = {} + if fake_file_system_gateway_contents is None: + fake_file_system_gateway_contents = {} + if fake_trigger_repository_contents is None: + fake_trigger_repository_contents = {} + if fake_cron_job_gateway_contents is None: + fake_cron_job_gateway_contents = {} + if fake_sync_inference_content is None: + fake_sync_inference_content = {} app.dependency_overrides[get_external_interfaces] = get_repositories_generator_wrapper( fake_docker_repository_image_always_exists=fake_docker_repository_image_always_exists, fake_model_bundle_repository_contents=fake_model_bundle_repository_contents, @@ -126,6 +160,11 @@ def get_test_client( fake_docker_image_batch_job_bundle_repository_contents=fake_docker_image_batch_job_bundle_repository_contents, fake_docker_image_batch_job_gateway_contents=fake_docker_image_batch_job_gateway_contents, fake_llm_fine_tuning_service_contents=fake_llm_fine_tuning_service_contents, + fake_file_storage_gateway_contents=fake_file_storage_gateway_contents, + fake_file_system_gateway_contents=fake_file_system_gateway_contents, + fake_trigger_repository_contents=fake_trigger_repository_contents, + fake_cron_job_gateway_contents=fake_cron_job_gateway_contents, + fake_sync_inference_content=fake_sync_inference_content, ) app.dependency_overrides[get_external_interfaces_read_only] = app.dependency_overrides[ get_external_interfaces @@ -136,6 +175,75 @@ def get_test_client( return get_test_client +@pytest.fixture +def get_async_test_client_wrapper(get_repositories_generator_wrapper): + def get_async_test_client( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents=None, + fake_model_endpoint_record_repository_contents=None, + fake_model_endpoint_infra_gateway_contents=None, + fake_batch_job_record_repository_contents=None, + fake_batch_job_progress_gateway_contents=None, + fake_docker_image_batch_job_bundle_repository_contents=None, + fake_docker_image_batch_job_gateway_contents=None, + fake_llm_fine_tuning_service_contents=None, + fake_file_storage_gateway_contents=None, + fake_file_system_gateway_contents=None, + fake_trigger_repository_contents=None, + fake_cron_job_gateway_contents=None, + fake_sync_inference_content=None, + ) -> AsyncClient: + if fake_docker_image_batch_job_gateway_contents is None: + fake_docker_image_batch_job_gateway_contents = {} + if fake_docker_image_batch_job_bundle_repository_contents is None: + fake_docker_image_batch_job_bundle_repository_contents = {} + if fake_batch_job_progress_gateway_contents is None: + fake_batch_job_progress_gateway_contents = {} + if fake_batch_job_record_repository_contents is None: + fake_batch_job_record_repository_contents = {} + if fake_model_endpoint_infra_gateway_contents is None: + fake_model_endpoint_infra_gateway_contents = {} + if fake_model_endpoint_record_repository_contents is None: + fake_model_endpoint_record_repository_contents = {} + if fake_model_bundle_repository_contents is None: + fake_model_bundle_repository_contents = {} + if fake_llm_fine_tuning_service_contents is None: + fake_llm_fine_tuning_service_contents = {} + if fake_file_storage_gateway_contents is None: + fake_file_storage_gateway_contents = {} + if fake_file_system_gateway_contents is None: + fake_file_system_gateway_contents = {} + if fake_trigger_repository_contents is None: + fake_trigger_repository_contents = {} + if fake_cron_job_gateway_contents is None: + fake_cron_job_gateway_contents = {} + if fake_sync_inference_content is None: + fake_sync_inference_content = {} + app.dependency_overrides[get_external_interfaces] = get_repositories_generator_wrapper( + fake_docker_repository_image_always_exists=fake_docker_repository_image_always_exists, + fake_model_bundle_repository_contents=fake_model_bundle_repository_contents, + fake_model_endpoint_record_repository_contents=fake_model_endpoint_record_repository_contents, + fake_model_endpoint_infra_gateway_contents=fake_model_endpoint_infra_gateway_contents, + fake_batch_job_record_repository_contents=fake_batch_job_record_repository_contents, + fake_batch_job_progress_gateway_contents=fake_batch_job_progress_gateway_contents, + fake_docker_image_batch_job_bundle_repository_contents=fake_docker_image_batch_job_bundle_repository_contents, + fake_docker_image_batch_job_gateway_contents=fake_docker_image_batch_job_gateway_contents, + fake_llm_fine_tuning_service_contents=fake_llm_fine_tuning_service_contents, + fake_file_storage_gateway_contents=fake_file_storage_gateway_contents, + fake_file_system_gateway_contents=fake_file_system_gateway_contents, + fake_trigger_repository_contents=fake_trigger_repository_contents, + fake_cron_job_gateway_contents=fake_cron_job_gateway_contents, + fake_sync_inference_content=fake_sync_inference_content, + ) + app.dependency_overrides[get_external_interfaces_read_only] = app.dependency_overrides[ + get_external_interfaces + ] + client = AsyncClient(app=app, base_url="http://test") + return client + + return get_async_test_client + + @pytest.fixture def simple_client(get_test_client_wrapper) -> TestClient: """Returns a Client with no initial contents and a Docker repository that always returns True""" @@ -147,6 +255,7 @@ def simple_client(get_test_client_wrapper) -> TestClient: fake_batch_job_record_repository_contents={}, fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, + fake_trigger_repository_contents={}, ) return client @@ -505,7 +614,7 @@ def create_model_endpoint_request_async( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", - "storage": None, + "storage": "2G", "min_workers": 0, "max_workers": 5, "per_worker": 3, @@ -531,7 +640,7 @@ def create_model_endpoint_request_sync( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-ampere-a10", - "storage": None, + "storage": "2G", "min_workers": 1, "max_workers": 5, "per_worker": 3, @@ -557,7 +666,7 @@ def create_model_endpoint_request_streaming( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-ampere-a10", - "storage": None, + "storage": "2G", "min_workers": 1, "max_workers": 5, "per_worker": 1, @@ -583,7 +692,7 @@ def create_model_endpoint_request_streaming_invalid_bundle( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-ampere-a10", - "storage": None, + "storage": "2G", "min_workers": 1, "max_workers": 5, "per_worker": 1, @@ -609,7 +718,7 @@ def create_model_endpoint_request_sync_invalid_streaming_bundle( "gpus": 1, "memory": "1G", "gpu_type": "nvidia-ampere-a10", - "storage": None, + "storage": "2G", "min_workers": 1, "max_workers": 5, "per_worker": 1, @@ -675,6 +784,7 @@ def model_endpoint_1( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -685,7 +795,7 @@ def model_endpoint_1( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -727,11 +837,12 @@ def model_endpoint_1( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -780,6 +891,7 @@ def model_endpoint_2( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -790,7 +902,18 @@ def model_endpoint_2( post_inference_hooks=None, default_callback_url=None, default_callback_auth=None, + billing_tags={ + "idempotencyKeyPrefix": "value1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, user_id=test_api_key, + billing_queue="some:arn:for:something", ), ), image="test_image_2", @@ -827,6 +950,7 @@ def model_endpoint_2( "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": False, }, "image": "test_image_2", @@ -1050,6 +1174,45 @@ def docker_image_batch_job_bundle_2_v1(test_api_key) -> Tuple[DockerImageBatchJo return batch_bundle, batch_bundle_json +@pytest.fixture +def docker_image_batch_job_bundle_3_v1(test_api_key) -> Tuple[DockerImageBatchJobBundle, Any]: + batch_bundle = DockerImageBatchJobBundle( + id="test_docker_image_batch_job_bundle_id_31", + created_at=datetime.datetime(2022, 1, 2), + name="test_docker_image_batch_job_bundle_3", + created_by=test_api_key, + owner=test_api_key, + image_repository="image_repository", + image_tag="image_tag_git_sha", + command=["python", "script3.py", "--arg1"], + env=dict(ENV1="VAL1", ENV2="VAL2"), + mount_location="/mount2/location/to/config", + cpus="3", + memory="5G", + storage="5G", + gpus=None, + gpu_type=None, + public=None, + ) + batch_bundle_json = { + "id": "test_docker_image_batch_job_bundle_id_31", + "name": "test_docker_image_batch_job_bundle_3", + "created_at": "2022-01-02T00:00:00", + "image_repository": "image_repository", + "image_tag": "image_tag_git_sha", + "command": ["python", "script3.py", "--arg1"], + "env": {"ENV1": "VAL1", "ENV2": "VAL2"}, + "mount_location": "/mount2/location/to/config", + "cpus": "3", + "memory": "5G", + "storage": "5G", + "gpus": None, + "gpu_type": None, + "public": None, + } + return batch_bundle, batch_bundle_json + + @pytest.fixture def create_docker_image_batch_job_request() -> Dict[str, Any]: return dict( @@ -1103,7 +1266,8 @@ def create_llm_model_endpoint_request_sync() -> Dict[str, Any]: "gpus": 2, "memory": "1G", "gpu_type": "nvidia-tesla-t4", - "storage": None, + "storage": "1Gi", + "nodes_per_worker": 1, "min_workers": 1, "max_workers": 5, "per_worker": 3, @@ -1114,14 +1278,104 @@ def create_llm_model_endpoint_request_sync() -> Dict[str, Any]: @pytest.fixture def completion_sync_request() -> Dict[str, Any]: - return {"prompts": ["what is 1+1?"], "max_new_tokens": 10, "temperature": 0.1} + return { + "prompt": "what is 1+1?", + "max_new_tokens": 10, + "temperature": 0.1, + } @pytest.fixture -def completion_sync_request_temperature_zero() -> Dict[str, Any]: - return {"prompts": ["what is 1+1?"], "max_new_tokens": 10, "temperature": 0} +def completion_stream_request() -> Dict[str, Any]: + return {"prompt": "what is 1+1?", "max_new_tokens": 10, "temperature": 0.1} @pytest.fixture -def completion_stream_request() -> Dict[str, Any]: - return {"prompt": "what is 1+1?", "max_new_tokens": 10, "temperature": 0.1} +def create_trigger_request() -> Dict[str, Any]: + return dict( + name="test_trigger_1", + cron_schedule="* * * * *", + bundle_id="test_docker_image_batch_job_bundle_id_31", + default_job_config={}, + default_job_metadata=dict(team="infra", product="my_product"), + ) + + +@pytest.fixture +def update_trigger_request() -> Dict[str, Any]: + return dict(cron_schedule="0 * * * *", suspend=True) + + +@pytest.fixture +def trigger_1(test_api_key) -> Tuple[Trigger, Any]: + trigger = Trigger( + id="test_trigger_id_1", + name="test_trigger_1", + owner=test_api_key, + created_by=test_api_key, + created_at=datetime.datetime(2022, 1, 2), + cron_schedule="* * * * *", + docker_image_batch_job_bundle_id="test_docker_image_batch_job_bundle_id_11", + default_job_config={}, + default_job_metadata=dict(team="infra", product="my_product_one"), + ) + trigger_json = { + "id": "test_trigger_id_1", + "name": "test_trigger_1", + "owner": "test_user_id", + "created_by": "test_user_id", + "created_at": "2022-01-02T00:00:00", + "cron_schedule": "* * * * *", + "docker_image_batch_job_bundle_id": "test_docker_image_batch_job_bundle_id_11", + "default_job_config": {}, + "default_job_metadata": {"team": "infra", "product": "my_product_one"}, + } + return trigger, trigger_json + + +@pytest.fixture +def trigger_2(test_api_key) -> Tuple[Trigger, Any]: + trigger = Trigger( + id="test_trigger_id_2", + name="test_trigger_2", + owner=test_api_key, + created_by=test_api_key, + created_at=datetime.datetime(2022, 2, 2), + cron_schedule="0 * * * *", + docker_image_batch_job_bundle_id="test_docker_image_batch_job_bundle_id_12", + default_job_config={}, + default_job_metadata=dict(team="infra", product="my_product_two"), + ) + trigger_json = { + "id": "test_trigger_id_2", + "name": "test_trigger_2", + "owner": "test_user_id", + "created_by": "test_user_id", + "created_at": "2022-02-02T00:00:00", + "cron_schedule": "0 * * * *", + "docker_image_batch_job_bundle_id": "test_docker_image_batch_job_bundle_id_12", + "default_job_config": {}, + "default_job_metadata": {"team": "infra", "product": "my_product_two"}, + } + return trigger, trigger_json + + +@pytest.fixture +def create_batch_completions_request() -> Dict[str, Any]: + return { + "input_data_path": "test_input_data_path", + "output_data_path": "test_output_data_path", + "content": { + "prompts": ["what is 1+1?"], + "max_new_tokens": 10, + "temperature": 0.1, + }, + "model_config": { + "model": "mpt-7b", + "checkpoint_path": "s3://test_checkpoint_path", + "labels": {}, + "num_shards": 2, + }, + "data_parallelism": 1, + "max_runtime_sec": 86400, + } diff --git a/server/tests/unit/api/test_app.py b/model-engine/tests/unit/api/test_app.py similarity index 100% rename from server/tests/unit/api/test_app.py rename to model-engine/tests/unit/api/test_app.py diff --git a/server/tests/unit/api/test_batch_jobs.py b/model-engine/tests/unit/api/test_batch_jobs.py similarity index 79% rename from server/tests/unit/api/test_batch_jobs.py rename to model-engine/tests/unit/api/test_batch_jobs.py index 87216d9a..426bd1d3 100644 --- a/server/tests/unit/api/test_batch_jobs.py +++ b/model-engine/tests/unit/api/test_batch_jobs.py @@ -2,9 +2,15 @@ import pytest from fastapi.testclient import TestClient -from llm_engine_server.domain.entities import BatchJob, GpuType, ModelBundle, ModelEndpoint -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities import ( + BatchJob, + DockerImageBatchJob, + GpuType, + ModelBundle, + ModelEndpoint, + Trigger, +) +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) @@ -35,6 +41,33 @@ def test_create_batch_job_success( assert "job_id" in response.json() +@pytest.mark.skip(reason="TODO: team validation is currently disabled") +def test_create_batch_job_invalid_team_returns_400( + model_bundle_1_v1: Tuple[ModelBundle, Any], + create_batch_job_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + create_batch_job_request["labels"]["team"] = "invalid_team" + response = client.post( + "/v1/batch-jobs", + auth=(test_api_key, ""), + json=create_batch_job_request, + ) + assert response.status_code == 400 + + def test_create_batch_job_bundle_not_found_returns_404( create_batch_job_request: Dict[str, Any], test_api_key: str, @@ -260,9 +293,9 @@ def test_create_docker_image_batch_job_unauthorized( } ) del create_docker_image_batch_job_request["docker_image_batch_job_bundle_name"] - create_docker_image_batch_job_request[ - "docker_image_batch_job_bundle_id" - ] = docker_image_batch_job_bundle_1_v1[0].id + create_docker_image_batch_job_request["docker_image_batch_job_bundle_id"] = ( + docker_image_batch_job_bundle_1_v1[0].id + ) response = client.post( "/v1/docker-image-batch-jobs", auth=(test_api_key_2, ""), @@ -302,9 +335,9 @@ def test_create_docker_image_batch_job_bundle_id_and_name( docker_image_batch_job_bundle_1_v1[0].id: docker_image_batch_job_bundle_1_v1[0] } ) - create_docker_image_batch_job_request[ - "docker_image_batch_job_bundle_id" - ] = docker_image_batch_job_bundle_1_v1[0].id + create_docker_image_batch_job_request["docker_image_batch_job_bundle_id"] = ( + docker_image_batch_job_bundle_1_v1[0].id + ) response = client.post( "/v1/docker-image-batch-jobs", auth=(test_api_key, ""), @@ -393,6 +426,26 @@ def test_create_docker_image_batch_job_no_image( assert response.status_code == 404 +def test_create_docker_image_batch_job_invalid_time_limit( + test_api_key: str, + get_test_client_wrapper, + create_docker_image_batch_job_request: Dict[str, Any], + docker_image_batch_job_bundle_1_v1: Tuple[DockerImageBatchJobBundle, Any], +): + client = get_test_client_wrapper( + fake_docker_image_batch_job_bundle_repository_contents={ + docker_image_batch_job_bundle_1_v1[0].id: docker_image_batch_job_bundle_1_v1[0] + } + ) + create_docker_image_batch_job_request["override_job_max_runtime_s"] = -1 + response = client.post( + "/v1/docker-image-batch-jobs", + auth=(test_api_key, ""), + json=create_docker_image_batch_job_request, + ) + assert response.status_code == 400 + + def test_get_docker_image_batch_job_success( test_api_key: str, get_test_client_wrapper, @@ -440,6 +493,83 @@ def test_get_docker_image_batch_job_not_exist( assert response.status_code == 404 +def test_list_jobs_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + "/v1/docker-image-batch-jobs", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert "jobs" in response.json() + + +def test_list_jobs_by_trigger_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + f"/v1/docker-image-batch-jobs?trigger_id={trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert "jobs" in response.json() + + +def test_list_jobs_by_trigger_not_found_returns_404( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + "/v1/docker-image-batch-jobs?trigger_id=some_trigger_id", + auth=(test_api_key, ""), + ) + assert response.status_code == 404 + + +def test_list_jobs_by_trigger_unauthorized_returns_404( + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + f"/v1/docker-image-batch-jobs?trigger_id={trigger_1[0].id}", + auth=("some_invalid_id", ""), + ) + assert response.status_code == 404 + + def test_update_docker_image_batch_job_noop( test_api_key: str, get_test_client_wrapper, diff --git a/server/tests/unit/api/test_docker_image_batch_job_bundles.py b/model-engine/tests/unit/api/test_docker_image_batch_job_bundles.py similarity index 99% rename from server/tests/unit/api/test_docker_image_batch_job_bundles.py rename to model-engine/tests/unit/api/test_docker_image_batch_job_bundles.py index 49e4d09a..2aa12a30 100644 --- a/server/tests/unit/api/test_docker_image_batch_job_bundles.py +++ b/model-engine/tests/unit/api/test_docker_image_batch_job_bundles.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Tuple from fastapi.testclient import TestClient -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) diff --git a/model-engine/tests/unit/api/test_llms.py b/model-engine/tests/unit/api/test_llms.py new file mode 100644 index 00000000..9e8fbc95 --- /dev/null +++ b/model-engine/tests/unit/api/test_llms.py @@ -0,0 +1,290 @@ +import json +from typing import Any, Dict, Tuple +from unittest import mock + +import pytest +from model_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response +from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus +from model_engine_server.domain.entities import ModelEndpoint +from tests.unit.domain.test_llm_use_cases import mocked__get_latest_batch_tag + +from ..conftest import mocked__get_recommended_hardware_config_map + + +def test_create_llm_model_endpoint_success( + create_llm_model_endpoint_request_sync: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + response_1 = client.post( + "/v1/llm/model-endpoints", + auth=(test_api_key, ""), + json=create_llm_model_endpoint_request_sync, + ) + assert response_1.status_code == 200 + + +def test_list_model_endpoints_success( + llm_model_endpoint_async: Tuple[ModelEndpoint, Any], + model_endpoint_2: Tuple[ModelEndpoint, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_model_endpoint_record_repository_contents={ + llm_model_endpoint_async[0].record.id: llm_model_endpoint_async[0].record, + }, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_async[0] + .infra_state.deployment_name: llm_model_endpoint_async[0] + .infra_state, + model_endpoint_2[0].infra_state.deployment_name: model_endpoint_2[0].infra_state, + }, + ) + response_1 = client.get( + "/v1/llm/model-endpoints?order_by=newest", + auth=("no_user", ""), + ) + expected_model_endpoint_1 = json.loads( + GetLLMModelEndpointV1Response.parse_obj(llm_model_endpoint_async[1]).json() + ) + assert response_1.status_code == 200 + assert response_1.json() == {"model_endpoints": [expected_model_endpoint_1]} + + +def test_get_llm_model_endpoint_success( + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + model_endpoint_2: Tuple[ModelEndpoint, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_model_endpoint_record_repository_contents={ + llm_model_endpoint_sync[0].record.id: llm_model_endpoint_sync[0].record, + }, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_sync[0] + .infra_state.deployment_name: llm_model_endpoint_sync[0] + .infra_state, + model_endpoint_2[0].infra_state.deployment_name: model_endpoint_2[0].infra_state, + }, + ) + response_1 = client.get( + f"/v1/llm/model-endpoints/{llm_model_endpoint_sync[0].record.name}", + auth=("no_user", ""), + ) + expected_model_endpoint_1 = json.loads( + GetLLMModelEndpointV1Response.parse_obj(llm_model_endpoint_sync[1]).json() + ) + assert response_1.status_code == 200 + assert response_1.json() == expected_model_endpoint_1 + + +def test_completion_sync_success( + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + completion_sync_request: Dict[str, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={ + llm_model_endpoint_sync[0].record.id: llm_model_endpoint_sync[0].record, + }, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_sync[0] + .infra_state.deployment_name: llm_model_endpoint_sync[0] + .infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + fake_sync_inference_content=SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": """{ + "text": "output", + "count_prompt_tokens": 1, + "count_output_tokens": 1 + }""" + }, + traceback=None, + ), + ) + response_1 = client.post( + f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}", + auth=("no_user", ""), + json=completion_sync_request, + ) + assert response_1.status_code == 200 + assert response_1.json()["output"] == { + "text": "output", + "num_completion_tokens": 1, + "num_prompt_tokens": 1, + "tokens": None, + } + assert response_1.json().keys() == {"output", "request_id"} + + +def test_completion_sync_endpoint_not_found_returns_404( + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + completion_sync_request: Dict[str, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_sync[0] + .infra_state.deployment_name: llm_model_endpoint_sync[0] + .infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + response_1 = client.post( + f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}", + auth=("no_user", ""), + json=completion_sync_request, + ) + assert response_1.status_code == 404 + + +@pytest.mark.asyncio +async def test_completion_stream_success( + llm_model_endpoint_streaming: ModelEndpoint, + completion_stream_request: Dict[str, Any], + get_async_test_client_wrapper, +): # pragma: no cover + async with get_async_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={ + llm_model_endpoint_streaming.record.id: llm_model_endpoint_streaming.record, + }, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) as client: + with mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=5, + ): + async with client.stream( + method="POST", + url=f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", + auth=("no_user", ""), + json=completion_stream_request, + ) as r: + assert r.status_code == 200 + count = 0 + async for message in r.aiter_bytes(): + decoded_message = message.decode("utf-8") + assert decoded_message.startswith( + "data: " + ), f"SSE does not start with 'data: ', message is '{decoded_message}'" + + # strip 'data: ' prefix from Server-sent events format + json_str = decoded_message[len("data: ") :] + parsed_data = json.loads(json_str.strip()) + assert parsed_data["request_id"] is not None + assert parsed_data["output"] is None + assert parsed_data["error"] is None + count += 1 + assert count == 1 + + +def test_completion_stream_endpoint_not_found_returns_404( + llm_model_endpoint_streaming: ModelEndpoint, + completion_stream_request: Dict[str, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + with client.stream( + method="POST", + url=f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", + auth=("no_user", ""), + json=completion_stream_request, + ) as r: + assert r.status_code == 404 + + +def test_completion_stream_misc_server_error_returns_500( + llm_model_endpoint_streaming: ModelEndpoint, + completion_stream_request: Dict[str, Any], + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={ + llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + with mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.CompletionStreamV1UseCase.execute", + ) as mock_stream_usecase: + mock_stream_usecase.side_effect = RuntimeError("Some server side runtime error.") + with client.stream( + method="POST", + url=f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", + auth=("no_user", ""), + json=completion_stream_request, + ) as r: + assert r.status_code == 500 + + +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_batch_tag", + mocked__get_latest_batch_tag(), +) +def test_create_batch_completions_success( + create_batch_completions_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + response_1 = client.post( + "/v1/llm/batch-completions", + auth=(test_api_key, ""), + json=create_batch_completions_request, + ) + assert response_1.status_code == 200 diff --git a/server/tests/unit/api/test_model_bundles.py b/model-engine/tests/unit/api/test_model_bundles.py similarity index 96% rename from server/tests/unit/api/test_model_bundles.py rename to model-engine/tests/unit/api/test_model_bundles.py index 66346f5f..c54e110c 100644 --- a/server/tests/unit/api/test_model_bundles.py +++ b/model-engine/tests/unit/api/test_model_bundles.py @@ -2,7 +2,7 @@ import pytest from fastapi.testclient import TestClient -from llm_engine_server.domain.entities import ModelBundle +from model_engine_server.domain.entities import ModelBundle @pytest.mark.parametrize("version", ["v1", "v2"]) @@ -84,10 +84,7 @@ def test_clone_model_bundle_success( response = client.post( f"/{version}/model-bundles/clone-with-changes", auth=(test_api_key, ""), - json={ - "original_model_bundle_id": model_bundle_1_v1[0].id, - "app_config": {"foo": "bar"}, - }, + json={"original_model_bundle_id": model_bundle_1_v1[0].id, "app_config": {"foo": "bar"}}, ) assert response.status_code == 200 response_json = response.json() @@ -116,10 +113,7 @@ def test_clone_model_bundle_unauthorized_returns_404( response = client.post( f"/{version}/model-bundles/clone-with-changes", auth=(test_api_key_2, ""), # Not the owner, should be unauthorized - json={ - "original_model_bundle_id": model_bundle_1_v1[0].id, - "app_config": {"foo": "bar"}, - }, + json={"original_model_bundle_id": model_bundle_1_v1[0].id, "app_config": {"foo": "bar"}}, ) assert response.status_code == 404 @@ -146,10 +140,7 @@ def test_clone_model_bundle_not_found_returns_404( response = client.post( f"/{version}/model-bundles/clone-with-changes", auth=(test_api_key, ""), - json={ - "original_model_bundle_id": "unknown model bundle id", - "app_config": {"foo": "bar"}, - }, + json={"original_model_bundle_id": "unknown model bundle id", "app_config": {"foo": "bar"}}, ) assert response.status_code == 404 diff --git a/server/tests/unit/api/test_model_endpoints.py b/model-engine/tests/unit/api/test_model_endpoints.py similarity index 85% rename from server/tests/unit/api/test_model_endpoints.py rename to model-engine/tests/unit/api/test_model_endpoints.py index 8961bbdc..1cc02f0b 100644 --- a/server/tests/unit/api/test_model_endpoints.py +++ b/model-engine/tests/unit/api/test_model_endpoints.py @@ -3,8 +3,9 @@ import pytest from fastapi.testclient import TestClient -from llm_engine_server.common.dtos.model_endpoints import GetModelEndpointV1Response -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint, ModelEndpointStatus +from model_engine_server.common.dtos.model_endpoints import GetModelEndpointV1Response +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint, ModelEndpointStatus +from model_engine_server.domain.use_cases.model_endpoint_use_cases import DEFAULT_DISALLOWED_TEAMS def test_create_model_endpoint_success( @@ -40,6 +41,42 @@ def test_create_model_endpoint_success( assert response_2.status_code == 200 +def test_create_model_endpoint_invalid_team_returns_400( + model_bundle_1_v1: Tuple[ModelBundle, Any], + create_model_endpoint_request_sync: Dict[str, Any], + create_model_endpoint_request_async: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + invalid_team_name = DEFAULT_DISALLOWED_TEAMS[0] + create_model_endpoint_request_sync["labels"]["team"] = invalid_team_name + response_1 = client.post( + "/v1/model-endpoints", + auth=(test_api_key, ""), + json=create_model_endpoint_request_sync, + ) + assert response_1.status_code == 400 + + create_model_endpoint_request_async["labels"]["team"] = invalid_team_name + response_2 = client.post( + "/v1/model-endpoints", + auth=(test_api_key, ""), + json=create_model_endpoint_request_async, + ) + assert response_2.status_code == 400 + + def test_create_model_endpoint_invalid_streaming_bundle_returns_400( model_bundle_1_v1: Tuple[ModelBundle, Any], create_model_endpoint_request_streaming_invalid_bundle: Dict[str, Any], @@ -188,6 +225,32 @@ def test_create_model_endpoint_endpoint_already_exists_returns_400( assert response_1.status_code == 400 +def test_create_model_endpoint_multinode_from_nonmultinode_bundle_returns_400( + model_bundle_1_v1: Tuple[ModelBundle, Any], + create_model_endpoint_request_sync: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + create_model_endpoint_request_sync["nodes_per_worker"] = 2 + response_1 = client.post( + "/v1/model-endpoints", + auth=(test_api_key, ""), + json=create_model_endpoint_request_sync, + ) + assert response_1.status_code == 400 + + def test_list_model_endpoints( model_bundle_1_v1: Tuple[ModelBundle, Any], model_endpoint_1: Tuple[ModelEndpoint, Any], @@ -358,6 +421,42 @@ def test_update_model_endpoint_by_id_success( assert response.json()["endpoint_creation_task_id"] +def test_update_model_endpoint_by_id_invalid_team_returns_400( + model_bundle_1_v1: Tuple[ModelBundle, Any], + model_endpoint_1: Tuple[ModelEndpoint, Any], + update_model_endpoint_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, +): + assert model_endpoint_1[0].infra_state is not None + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={ + model_endpoint_1[0].record.id: model_endpoint_1[0].record, + }, + fake_model_endpoint_infra_gateway_contents={ + model_endpoint_1[0].infra_state.deployment_name: model_endpoint_1[0].infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + invalid_team_name = DEFAULT_DISALLOWED_TEAMS[0] + update_model_endpoint_request["labels"] = { + "team": invalid_team_name, + "product": "my_product", + } + response = client.put( + "/v1/model-endpoints/test_model_endpoint_id_1", + auth=(test_api_key, ""), + json=update_model_endpoint_request, + ) + assert response.status_code == 400 + + def test_update_model_endpoint_by_id_endpoint_not_authorized_returns_404( model_bundle_1_v1: Tuple[ModelBundle, Any], model_endpoint_1: Tuple[ModelEndpoint, Any], diff --git a/server/tests/unit/api/test_model_endpoints_docs.py b/model-engine/tests/unit/api/test_model_endpoints_docs.py similarity index 97% rename from server/tests/unit/api/test_model_endpoints_docs.py rename to model-engine/tests/unit/api/test_model_endpoints_docs.py index 5ee1451b..04828d05 100644 --- a/server/tests/unit/api/test_model_endpoints_docs.py +++ b/model-engine/tests/unit/api/test_model_endpoints_docs.py @@ -1,6 +1,6 @@ from typing import Any, Tuple -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint def test_model_endpoints_schema_success( diff --git a/server/tests/unit/api/test_tasks.py b/model-engine/tests/unit/api/test_tasks.py similarity index 83% rename from server/tests/unit/api/test_tasks.py rename to model-engine/tests/unit/api/test_tasks.py index 360658e3..f2cee1bb 100644 --- a/server/tests/unit/api/test_tasks.py +++ b/model-engine/tests/unit/api/test_tasks.py @@ -1,13 +1,15 @@ from typing import Any, Dict, Tuple from unittest.mock import AsyncMock, MagicMock, patch -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.domain_exceptions import ( +import pytest +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.domain.exceptions import ( + InvalidRequestException, ObjectNotAuthorizedException, ObjectNotFoundException, + UpstreamServiceError, ) -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint -from llm_engine_server.domain.exceptions import UpstreamServiceError def test_create_async_task_success( @@ -104,6 +106,43 @@ def test_create_async_task_raises_404_not_found( assert response.status_code == 404 +def test_create_async_task_raises_400_invalid_requests( + model_bundle_1_v1: Tuple[ModelBundle, Any], + model_endpoint_1: Tuple[ModelEndpoint, Any], + endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]], + test_api_key: str, + get_test_client_wrapper, +): + assert model_endpoint_1[0].infra_state is not None + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={ + model_bundle_1_v1[0].id: model_bundle_1_v1[0], + }, + fake_model_endpoint_record_repository_contents={ + model_endpoint_1[0].record.id: model_endpoint_1[0].record, + }, + fake_model_endpoint_infra_gateway_contents={ + model_endpoint_1[0].infra_state.deployment_name: model_endpoint_1[0].infra_state, + }, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={}, + ) + mock_use_case = MagicMock() + mock_use_case.return_value.execute = MagicMock(side_effect=InvalidRequestException) + with patch( + "model_engine_server.api.tasks_v1.CreateAsyncInferenceTaskV1UseCase", + mock_use_case, + ): + response = client.post( + "/v1/async-tasks?model_endpoint_id=invalid_model_endpoint_id", + auth=(test_api_key, ""), + json=endpoint_predict_request_1[1], + ) + assert response.status_code == 400 + + def test_get_async_task_success( model_bundle_1_v1: Tuple[ModelBundle, Any], model_endpoint_1: Tuple[ModelEndpoint, Any], @@ -158,7 +197,7 @@ def test_get_async_task_raises_404_object_not_found( mock_use_case = MagicMock() mock_use_case.return_value.execute = MagicMock(side_effect=ObjectNotFoundException) with patch( - "llm_engine_server.api.tasks_v1.GetAsyncInferenceTaskV1UseCase", + "model_engine_server.api.tasks_v1.GetAsyncInferenceTaskV1UseCase", mock_use_case, ): response = client.get( @@ -193,7 +232,7 @@ def test_get_async_task_raises_404_object_not_authorized( mock_use_case = MagicMock() mock_use_case.return_value.execute = MagicMock(side_effect=ObjectNotAuthorizedException) with patch( - "llm_engine_server.api.tasks_v1.GetAsyncInferenceTaskV1UseCase", + "model_engine_server.api.tasks_v1.GetAsyncInferenceTaskV1UseCase", mock_use_case, ): response = client.get( @@ -325,7 +364,7 @@ def test_create_sync_task_returns_failure( side_effect=UpstreamServiceError(400, b"test_content") ) with patch( - "llm_engine_server.api.tasks_v1.CreateSyncInferenceTaskV1UseCase", + "model_engine_server.api.tasks_v1.CreateSyncInferenceTaskV1UseCase", mock_use_case, ): response = client.post( @@ -337,15 +376,16 @@ def test_create_sync_task_returns_failure( assert response.json()["status"] == "FAILURE" -def test_create_streaming_task_success( +@pytest.mark.asyncio +async def test_create_streaming_task_success( model_bundle_5: ModelBundle, model_endpoint_streaming: ModelEndpoint, endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]], test_api_key: str, - get_test_client_wrapper, + get_async_test_client_wrapper, ): assert model_endpoint_streaming.infra_state is not None - client = get_test_client_wrapper( + async with get_async_test_client_wrapper( fake_docker_repository_image_always_exists=True, fake_model_bundle_repository_contents={ model_bundle_5.id: model_bundle_5, @@ -359,15 +399,18 @@ def test_create_streaming_task_success( fake_batch_job_record_repository_contents={}, fake_batch_job_progress_gateway_contents={}, fake_docker_image_batch_job_bundle_repository_contents={}, - ) - response = client.post( - f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", - auth=(test_api_key, ""), - json=endpoint_predict_request_1[1], - ) - assert response.status_code == 200 - count = 0 - for message in response: - assert message == b'data: {"status": "SUCCESS", "result": null, "traceback": null}\r\n\r\n' - count += 1 - assert count == 1 + ) as client: + async with client.stream( + method="POST", + url=f"/v1/streaming-tasks?model_endpoint_id={model_endpoint_streaming.record.id}", + auth=(test_api_key, ""), + json=endpoint_predict_request_1[1], + ) as response: + assert response.status_code == 200 + count = 0 + async for message in response.aiter_bytes(): + assert ( + message == b'data: {"status":"SUCCESS","result":null,"traceback":null}\r\n\r\n' + ) + count += 1 + assert count == 1 diff --git a/model-engine/tests/unit/api/test_triggers.py b/model-engine/tests/unit/api/test_triggers.py new file mode 100644 index 00000000..ea9170ba --- /dev/null +++ b/model-engine/tests/unit/api/test_triggers.py @@ -0,0 +1,312 @@ +from typing import Any, Dict, Tuple + +from fastapi.testclient import TestClient +from model_engine_server.domain.entities import Trigger +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( + DockerImageBatchJobBundle, +) + + +def test_create_trigger_success( + create_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + docker_image_batch_job_bundle_3_v1: Tuple[DockerImageBatchJobBundle, Any], +): + # populate docker image batch bundle repo + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={ + docker_image_batch_job_bundle_3_v1[0].id: docker_image_batch_job_bundle_3_v1[0], + }, + ) + + response_1 = client.post( + "/v1/triggers", + auth=(test_api_key, ""), + json=create_trigger_request, + ) + assert response_1.status_code == 200 + assert "trigger_id" in response_1.json() + + +def test_create_trigger_batch_bundle_not_found_returns_404( + create_trigger_request: Dict[str, Any], + test_api_key: str, + simple_client: TestClient, +): + response_1 = simple_client.post( + "/v1/triggers", + auth=(test_api_key, ""), + json=create_trigger_request, + ) + assert response_1.status_code == 404 + + +def test_create_trigger_batch_bundle_unauthorized_returns_400( + create_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + docker_image_batch_job_bundle_3_v1: Tuple[DockerImageBatchJobBundle, Any], +): + # populate docker image batch bundle repo + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={ + docker_image_batch_job_bundle_3_v1[0].id: docker_image_batch_job_bundle_3_v1[0], + }, + ) + + response_1 = client.post( + "/v1/triggers", + auth=("some_invalid_id", ""), + json=create_trigger_request, + ) + assert response_1.status_code == 404 + + +def test_create_trigger_bad_cron_returns_400( + create_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + docker_image_batch_job_bundle_3_v1: Tuple[DockerImageBatchJobBundle, Any], +): + # populate docker image batch bundle repo + client = get_test_client_wrapper( + fake_docker_repository_image_always_exists=True, + fake_model_bundle_repository_contents={}, + fake_model_endpoint_record_repository_contents={}, + fake_model_endpoint_infra_gateway_contents={}, + fake_batch_job_record_repository_contents={}, + fake_batch_job_progress_gateway_contents={}, + fake_docker_image_batch_job_bundle_repository_contents={ + docker_image_batch_job_bundle_3_v1[0].id: docker_image_batch_job_bundle_3_v1[0], + }, + ) + + create_trigger_request["cron_schedule"] = "field is wrong" + response_1 = client.post( + "/v1/triggers", + auth=(test_api_key, ""), + json=create_trigger_request, + ) + assert response_1.status_code == 400 + + +def test_list_triggers_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + "/v1/triggers", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert response.json() == { + "triggers": [trigger_1[1], trigger_2[1]], + } + + +def test_get_trigger_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert response.json() == trigger_1[1] + + +def test_get_trigger_not_found_returns_404( + test_api_key: str, + simple_client: TestClient, +): + response = simple_client.get( + "/v1/triggers/some_trigger_id", + auth=(test_api_key, ""), + ) + assert response.status_code == 404 + + +def test_get_trigger_unauthorized_returns_404( + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.get( + f"/v1/triggers/{trigger_1[0].id}", + auth=("some_invalid_id", ""), + ) + assert response.status_code == 404 + + +def test_update_trigger_success( + update_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.put( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + json=update_trigger_request, + ) + assert response.json().get("success") + + response = client.get( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.status_code == 200 + assert response.json().get("cron_schedule") == "0 * * * *" + + +def test_update_trigger_not_found_returns_404( + update_trigger_request: Dict[str, Any], + test_api_key: str, + simple_client: TestClient, +): + response = simple_client.put( + "/v1/triggers/some_trigger_id", + auth=(test_api_key, ""), + json=update_trigger_request, + ) + assert response.status_code == 404 + + +def test_update_trigger_unauthorized_returns_404( + update_trigger_request: Dict[str, Any], + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.put( + f"/v1/triggers/{trigger_1[0].id}", + auth=("some_invalid_id", ""), + json=update_trigger_request, + ) + assert response.status_code == 404 + + +def test_update_trigger_bad_cron_returns_400( + update_trigger_request: Dict[str, Any], + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + + update_trigger_request["cron_schedule"] = "field is wrong" + response = client.put( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + json=update_trigger_request, + ) + assert response.status_code == 400 + + +def test_delete_trigger_success( + test_api_key: str, + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.delete( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.json().get("success") + + response = client.get( + f"/v1/triggers/{trigger_1[0].id}", + auth=(test_api_key, ""), + ) + assert response.status_code == 404 + + +def test_delete_trigger_not_found_returns_404( + test_api_key: str, + simple_client: TestClient, +): + response = simple_client.delete( + "/v1/triggers/some_trigger_id", + auth=(test_api_key, ""), + ) + assert response.status_code == 404 + + +def test_delete_trigger_unauthorized_returns_404( + get_test_client_wrapper, + trigger_1: Tuple[Trigger, Any], + trigger_2: Tuple[Trigger, Any], +): + client = get_test_client_wrapper( + fake_trigger_repository_contents={ + trigger_1[0].id: trigger_1[0], + trigger_2[0].id: trigger_2[0], + }, + ) + response = client.delete( + f"/v1/triggers/{trigger_1[0].id}", + auth=("some_invalid_id", ""), + ) + assert response.status_code == 404 diff --git a/model-engine/tests/unit/common/__init__.py b/model-engine/tests/unit/common/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/server/tests/unit/common/test_batch_jobs_dtos.py b/model-engine/tests/unit/common/test_batch_jobs_dtos.py similarity index 88% rename from server/tests/unit/common/test_batch_jobs_dtos.py rename to model-engine/tests/unit/common/test_batch_jobs_dtos.py index b5f704f0..2ba5499d 100644 --- a/server/tests/unit/common/test_batch_jobs_dtos.py +++ b/model-engine/tests/unit/common/test_batch_jobs_dtos.py @@ -1,4 +1,4 @@ -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests def test_create_docker_image_batch_job_resource_requests_merge_requests(): @@ -24,10 +24,10 @@ def test_create_docker_image_batch_job_resource_requests_merge_requests(): # Test merging default = CreateDockerImageBatchJobResourceRequests(cpus=0.5) override = CreateDockerImageBatchJobResourceRequests( - memory="100Mi", gpus=1, gpu_type="nvidia-a100", storage="10Gi" + memory="100Mi", gpus=1, gpu_type="nvidia-ampere-a100", storage="10Gi" ) expected = CreateDockerImageBatchJobResourceRequests( - cpus=0.5, memory="100Mi", gpus=1, gpu_type="nvidia-a100", storage="10Gi" + cpus=0.5, memory="100Mi", gpus=1, gpu_type="nvidia-ampere-a100", storage="10Gi" ) actual = CreateDockerImageBatchJobResourceRequests.merge_requests(default, override) assert expected == actual diff --git a/server/tests/unit/conftest.py b/model-engine/tests/unit/conftest.py similarity index 69% rename from server/tests/unit/conftest.py rename to model-engine/tests/unit/conftest.py index 705cf468..d96c86fa 100644 --- a/server/tests/unit/conftest.py +++ b/model-engine/tests/unit/conftest.py @@ -11,40 +11,42 @@ List, Optional, Sequence, + Set, Tuple, ) +from unittest import mock from unittest.mock import mock_open from uuid import uuid4 import pytest -from llm_engine_server.api.dependencies import ExternalInterfaces -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME -from llm_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests -from llm_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.api.dependencies import ExternalInterfaces +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.common.dtos.model_endpoints import ( BrokerType, CpuSpecificationType, GpuType, ModelEndpointOrderBy, StorageSpecificationType, ) -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.common.dtos.tasks import ( +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.dtos.tasks import ( CreateAsyncTaskV1Response, EndpointPredictV1Request, GetAsyncTaskV1Response, + SyncEndpointPredictV1Request, SyncEndpointPredictV1Response, TaskStatus, ) -from llm_engine_server.common.settings import generate_destination -from llm_engine_server.core.domain_exceptions import ObjectNotFoundException -from llm_engine_server.core.fake_notification_gateway import FakeNotificationGateway -from llm_engine_server.db.endpoint_row_lock import get_lock_key -from llm_engine_server.db.models import BatchJob as OrmBatchJob -from llm_engine_server.db.models import Endpoint as OrmModelEndpoint -from llm_engine_server.domain.entities import ( +from model_engine_server.common.settings import generate_destination +from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway +from model_engine_server.db.endpoint_row_lock import get_lock_key +from model_engine_server.db.models import BatchJob as OrmBatchJob +from model_engine_server.db.models import Endpoint as OrmModelEndpoint +from model_engine_server.domain.entities import ( BatchJob, BatchJobProgress, BatchJobRecord, @@ -54,6 +56,9 @@ CallbackBasicAuth, CloudpickleArtifactFlavor, CustomFramework, + FileMetadata, + FineTuneHparamValueType, + LLMFineTuneEvent, ModelBundle, ModelBundleEnvironmentParams, ModelBundleFlavors, @@ -73,63 +78,91 @@ RunnableImageFlavor, StreamingEnhancedRunnableImageFlavor, TensorflowFramework, + Trigger, TritonEnhancedRunnableImageFlavor, ZipArtifactFlavor, ) -from llm_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.domain.entities.batch_job_entity import DockerImageBatchJob +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.domain.gateways import ( +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.domain.exceptions import ( + EndpointResourceInfraException, + ObjectNotFoundException, +) +from model_engine_server.domain.gateways import ( AsyncModelEndpointInferenceGateway, + CronJobGateway, DockerImageBatchJobGateway, + FileStorageGateway, + InferenceAutoscalingMetricsGateway, + LLMArtifactGateway, StreamingModelEndpointInferenceGateway, SyncModelEndpointInferenceGateway, TaskQueueGateway, ) -from llm_engine_server.domain.repositories import ( +from model_engine_server.domain.repositories import ( DockerImageBatchJobBundleRepository, DockerRepository, + LLMFineTuneEventsRepository, ModelBundleRepository, + TokenizerRepository, + TriggerRepository, +) +from model_engine_server.domain.services import ( + LLMFineTuningService, + LLMModelEndpointService, + ModelEndpointService, ) -from llm_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService -from llm_engine_server.domain.services.llm_fine_tuning_service import LLMFineTuningService -from llm_engine_server.infra.gateways import ( +from model_engine_server.inference.domain.gateways.streaming_storage_gateway import ( + StreamingStorageGateway, +) +from model_engine_server.infra.gateways import ( BatchJobOrchestrationGateway, - FilesystemGateway, LiveBatchJobProgressGateway, LiveModelEndpointsSchemaGateway, ModelEndpointInfraGateway, ) -from llm_engine_server.infra.gateways.fake_model_primitive_gateway import FakeModelPrimitiveGateway -from llm_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( +from model_engine_server.infra.gateways.fake_model_primitive_gateway import ( + FakeModelPrimitiveGateway, +) +from model_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( FakeMonitoringMetricsGateway, ) -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( +from model_engine_server.infra.gateways.filesystem_gateway import FilesystemGateway +from model_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( EndpointResourceGateway, EndpointResourceGatewayCreateOrUpdateResourcesResponse, QueueInfo, ) -from llm_engine_server.infra.gateways.resources.image_cache_gateway import ( +from model_engine_server.infra.gateways.resources.image_cache_gateway import ( CachedImages, ImageCacheGateway, ) -from llm_engine_server.infra.repositories import ( +from model_engine_server.infra.repositories import ( BatchJobRecordRepository, FeatureFlagRepository, + LLMFineTuneRepository, ModelEndpointCacheRepository, ModelEndpointRecordRepository, ) -from llm_engine_server.infra.repositories.db_model_bundle_repository import ( +from model_engine_server.infra.repositories.db_model_bundle_repository import ( translate_kwargs_to_model_bundle_orm, translate_model_bundle_orm_to_model_bundle, ) -from llm_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService -from llm_engine_server.infra.services.image_cache_service import ImageCacheService -from llm_engine_server.infra.services.live_llm_model_endpoint_service import ( +from model_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService +from model_engine_server.infra.services.fake_llm_batch_completions_service import ( + FakeLLMBatchCompletionsService, +) +from model_engine_server.infra.services.image_cache_service import ImageCacheService +from model_engine_server.infra.services.live_llm_batch_completions_service import ( + LiveLLMBatchCompletionsService, +) +from model_engine_server.infra.services.live_llm_model_endpoint_service import ( LiveLLMModelEndpointService, ) +from transformers import AutoTokenizer def _translate_fake_model_endpoint_orm_to_model_endpoint_record( @@ -647,7 +680,10 @@ def get_image_url(self, image_tag: str, repository_name: str) -> str: def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: if self.raises_error: raise Exception("I hope you're handling this!") - return BuildImageResponse(status=True, logs="") + return BuildImageResponse(status=True, logs="", job_name="test-job-name") + + def get_latest_image_tag(self, repository_name: str) -> str: + return "fake_docker_repository_latest_image_tag" class FakeModelEndpointCacheRepository(ModelEndpointCacheRepository): @@ -693,9 +729,225 @@ async def read_feature_flag_bool( return self.db.get(key, None) +class FakeLLMFineTuneRepository(LLMFineTuneRepository): + def __init__(self, db: Optional[Dict[Tuple[str, str], LLMFineTuneTemplate]] = None): + self.db = db + if self.db is None: + self.db = {} + + async def get_job_template_for_model( + self, model_name: str, fine_tuning_method: str + ) -> Optional[LLMFineTuneTemplate]: + return self.db.get((model_name, fine_tuning_method), None) + + async def write_job_template_for_model( + self, + model_name: str, + fine_tuning_method: str, + job_template: LLMFineTuneTemplate, + ): + self.db[(model_name, fine_tuning_method)] = job_template + + +class FakeLLMFineTuneEventsRepository(LLMFineTuneEventsRepository): + def __init__(self): + self.initialized_events = [] + self.all_events_list = [LLMFineTuneEvent(timestamp=1, message="message", level="info")] + + async def get_fine_tune_events(self, user_id: str, model_endpoint_name: str): + if (user_id, model_endpoint_name) in self.initialized_events: + return self.all_events_list + raise ObjectNotFoundException + + async def initialize_events(self, user_id: str, model_endpoint_name: str): + self.initialized_events.append((user_id, model_endpoint_name)) + + +class FakeLLMArtifactGateway(LLMArtifactGateway): + def __init__(self): + self.existing_models = [] + self.s3_bucket = { + "fake-checkpoint": [ + "model-fake.bin, model-fake2.bin", + "model-fake.safetensors", + ], + "llama-7b/tokenizer.json": ["llama-7b/tokenizer.json"], + "llama-7b/tokenizer_config.json": ["llama-7b/tokenizer_config.json"], + "llama-7b/special_tokens_map.json": ["llama-7b/special_tokens_map.json"], + "llama-2-7b": ["model-fake.safetensors"], + "mpt-7b": ["model-fake.safetensors"], + "llama-3-70b": ["model-fake.safetensors"], + "llama-3-1-405b-instruct": ["model-fake.safetensors"], + } + self.urls = {"filename": "https://test-bucket.s3.amazonaws.com/llm/llm-1.0.0.tar.gz"} + self.model_config = { + "_name_or_path": "meta-llama/Llama-2-7b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "tie_word_embeddings": False, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "use_cache": True, + "vocab_size": 32000, + } + self.tokenizer_config = { + "add_bos_token": True, + "add_eos_token": False, + "add_prefix_space": None, + "added_tokens_decoder": { + "0": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + "1": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + "2": { + "content": "", + "lstrip": False, + "normalized": False, + "rstrip": False, + "single_word": False, + "special": True, + }, + }, + "additional_special_tokens": [], + "bos_token": "", + "chat_template": "{%- if messages[0]['role'] == 'system' %}\n {%- set system_message = messages[0]['content'] %}\n {%- set loop_messages = messages[1:] %}\n{%- else %}\n {%- set loop_messages = messages %}\n{%- endif %}\n\n{{- bos_token }}\n{%- for message in loop_messages %}\n {%- if (message['role'] == 'user') != (loop.index0 % 2 == 0) %}\n {{- raise_exception('After the optional system message, conversation roles must alternate user/assistant/user/assistant/...') }}\n {%- endif %}\n {%- if message['role'] == 'user' %}\n {%- if loop.first and system_message is defined %}\n {{- ' [INST] ' + system_message + '\\n\\n' + message['content'] + ' [/INST]' }}\n {%- else %}\n {{- ' [INST] ' + message['content'] + ' [/INST]' }}\n {%- endif %}\n {%- elif message['role'] == 'assistant' %}\n {{- ' ' + message['content'] + eos_token}}\n {%- else %}\n {{- raise_exception('Only user and assistant roles are supported, with the exception of an initial optional system message!') }}\n {%- endif %}\n{%- endfor %}\n", + "clean_up_tokenization_spaces": False, + "eos_token": "", + "legacy": False, + "model_max_length": 1000000000000000019884624838656, + "pad_token": None, + "sp_model_kwargs": {}, + "spaces_between_special_tokens": False, + "tokenizer_class": "LlamaTokenizer", + "unk_token": "", + "use_default_system_prompt": False, + } + + def _add_model(self, owner: str, model_name: str): + self.existing_models.append((owner, model_name)) + + def list_files(self, path: str, **kwargs) -> List[str]: + path = path.lstrip("s3://") + if path in self.s3_bucket: + return self.s3_bucket[path] + + def download_files(self, path: str, target_path: str, overwrite=False, **kwargs) -> List[str]: + path = path.lstrip("s3://") + if path in self.s3_bucket: + return self.s3_bucket[path] + + def get_model_weights_urls(self, owner: str, model_name: str): + if (owner, model_name) in self.existing_models: + return self.urls + raise ObjectNotFoundException + + def get_model_config(self, path: str, **kwargs) -> Dict[str, Any]: + return self.model_config + + +class FakeTriggerRepository(TriggerRepository): # pragma: no cover + def __init__(self, contents: Optional[Dict[str, Trigger]] = None): + self.db = {} if contents is None else contents + self.next_id = 0 + + def _get_new_id(self): + new_id = f"trig_{self.next_id}" + self.next_id += 1 + return new_id + + async def create_trigger( + self, + *, + name: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Optional[Dict[str, str]], + ) -> Trigger: + trigger_id = self._get_new_id() + trigger = Trigger( + id=trigger_id, + name=name, + owner=owner, + created_by=created_by, + created_at=datetime.now(), + cron_schedule=cron_schedule, + docker_image_batch_job_bundle_id=docker_image_batch_job_bundle_id, + default_job_config=default_job_config, + default_job_metadata=default_job_metadata, + ) + self.db[trigger_id] = trigger + return trigger + + async def list_triggers( + self, + owner: str, + ) -> Sequence[Trigger]: + def filter_fn(trig: Trigger) -> bool: + return trig.owner == owner + + return list(filter(filter_fn, self.db.values())) + + async def get_trigger( + self, + trigger_id: str, + ) -> Optional[Trigger]: + return self.db.get(trigger_id) + + async def update_trigger( + self, + trigger_id: str, + cron_schedule: str, + ) -> bool: + if trigger_id not in self.db: + return False + + self.db[trigger_id].cron_schedule = cron_schedule + return True + + async def delete_trigger( + self, + trigger_id: str, + ) -> bool: + if trigger_id not in self.db: + return False + + del self.db[trigger_id] + return True + + class FakeImageCacheGateway(ImageCacheGateway): def __init__(self): - self.cached_images = CachedImages(cpu=[], a10=[], a100=[], t4=[]) + self.cached_images = CachedImages( + cpu=[], a10=[], a100=[], t4=[], h100=[], h100_1g20gb=[], h100_3g40gb=[] + ) async def create_or_update_image_cache(self, cached_images: CachedImages) -> None: self.cached_images = cached_images @@ -830,7 +1082,8 @@ def create_model_endpoint_infra( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, aws_role: str, results_s3_bucket: str, @@ -839,6 +1092,7 @@ def create_model_endpoint_infra( labels: Dict[str, str], prewarm: Optional[bool], high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str], default_callback_auth: Optional[CallbackAuth], ) -> str: @@ -864,6 +1118,7 @@ def create_model_endpoint_infra( gpu_type=gpu_type, memory=memory, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, ), user_config_state=ModelEndpointUserConfigState( @@ -937,6 +1192,7 @@ async def update_model_endpoint_infra( labels: Optional[Dict[str, str]] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, ) -> str: @@ -1000,7 +1256,7 @@ async def create_queue( """Creates a new, unique queue name. Used by this endpoint resource gateway to create new resources. """ - return QueueInfo(queue_name="foobar", broker=BrokerType.REDIS) + return QueueInfo(queue_name="foobar", queue_url=None) async def create_or_update_resources( self, request: CreateOrUpdateResourcesRequest @@ -1029,6 +1285,7 @@ async def create_or_update_resources( gpu_type=build_endpoint_request.gpu_type, memory=build_endpoint_request.memory, storage=build_endpoint_request.storage, + nodes_per_worker=build_endpoint_request.nodes_per_worker, optimize_costs=build_endpoint_request.optimize_costs, ), user_config_state=ModelEndpointUserConfigState( @@ -1087,8 +1344,11 @@ async def create_docker_image_batch_job( resource_requests: CreateDockerImageBatchJobResourceRequests, labels: Dict[str, str], mount_location: Optional[str], + annotations: Optional[Dict[str, str]] = None, + override_job_max_runtime_s: Optional[int] = None, + num_workers: Optional[int] = 1, ) -> str: - job_id = f"job-{self.id}" + job_id = f"ft-{self.id}" self.id += 1 self.db[job_id] = DockerImageBatchJob( @@ -1098,6 +1358,9 @@ async def create_docker_image_batch_job( created_at=datetime.now(), completed_at=None, status=BatchJobStatus.RUNNING, + annotations=annotations, + override_job_max_runtime_s=override_job_max_runtime_s, + num_workers=num_workers, ) return job_id @@ -1118,54 +1381,136 @@ async def update_docker_image_batch_job(self, batch_job_id: str, cancel: bool) - return cancel +class FakeCronJobGateway(CronJobGateway): + def __init__(self, contents=None): + self.db = contents or {} + self.suspended_cronjobs: Set[str] = set() + self.id = 0 + + async def create_cronjob( + self, + *, + request_host: str, + trigger_id: str, + created_by: str, + owner: str, + cron_schedule: str, + docker_image_batch_job_bundle_id: str, + default_job_config: Optional[Dict[str, Any]], + default_job_metadata: Dict[str, str], + ) -> None: + cron_job_id = f"cronjob-{trigger_id}" + self.id += 1 + + self.db[cron_job_id] = Trigger( + id=cron_job_id, + name=cron_job_id, + owner=owner, + created_by=created_by, + created_at=datetime.now(), + cron_schedule=cron_schedule, + docker_image_batch_job_bundle_id=docker_image_batch_job_bundle_id, + default_job_config=default_job_config, + default_job_metadata=default_job_metadata, + ) + + async def list_jobs( + self, + *, + owner: str, + trigger_id: Optional[str], + ) -> List[DockerImageBatchJob]: + return [] + + async def update_cronjob( + self, + *, + trigger_id: str, + cron_schedule: Optional[str], + suspend: Optional[bool], + ) -> None: + cron_job_id = f"cronjob-{trigger_id}" + if cron_job_id not in self.db: + return + + if cron_schedule is not None: + self.db[cron_job_id].cron_schedule = cron_schedule + if suspend is not None: + if suspend: + self.suspended_cronjobs.add(cron_job_id) + else: + self.suspended_cronjobs.discard(cron_job_id) + + async def delete_cronjob( + self, + *, + trigger_id: str, + ) -> None: + cron_job_id = f"cronjob-{trigger_id}" + self.db.pop(cron_job_id, None) + self.suspended_cronjobs.discard(cron_job_id) + + class FakeLLMFineTuningService(LLMFineTuningService): def __init__(self, contents=None): self.db: Dict[str, DockerImageBatchJob] = {} if contents is None else contents self.id = 0 - async def create_fine_tune_job( + async def create_fine_tune( self, created_by: str, owner: str, + model: str, training_file: str, - validation_file: str, - model_name: str, - base_model: str, + validation_file: Optional[str], fine_tuning_method: str, - hyperparameters: Dict[str, str], + hyperparameters: Dict[str, FineTuneHparamValueType], + fine_tuned_model: str, + wandb_config: Optional[Dict[str, Any]], ) -> str: - job_id = f"job-{self.id}" + job_id = f"ft-{self.id}" self.id += 1 + now = datetime.now() + self.db[job_id] = DockerImageBatchJob( id=job_id, created_by=created_by, owner=owner, - created_at=datetime.now(), + created_at=now, completed_at=None, status=BatchJobStatus.RUNNING, + annotations={ + "fine_tuned_model": fine_tuned_model, + }, ) return job_id - async def get_fine_tune_job( - self, owner: str, fine_tune_id: str - ) -> Optional[DockerImageBatchJob]: + async def get_fine_tune(self, owner: str, fine_tune_id: str) -> Optional[DockerImageBatchJob]: di_batch_job = self.db.get(fine_tune_id) - if di_batch_job is None or di_batch_job["owner"] != owner: + if di_batch_job is None or di_batch_job.owner != owner: return None return di_batch_job - async def list_fine_tune_jobs(self, owner: str) -> List[DockerImageBatchJob]: - return [job for job in self.db.values() if job["owner"] == owner] + async def list_fine_tunes(self, owner: str) -> List[DockerImageBatchJob]: + return [job for job in self.db.values() if job.owner == owner] - async def cancel_fine_tune_job(self, owner: str, fine_tune_id: str) -> bool: - if fine_tune_id not in self.db or self.db.get(fine_tune_id)["owner"] != owner: + async def cancel_fine_tune(self, owner: str, fine_tune_id: str) -> bool: + if fine_tune_id not in self.db or self.db.get(fine_tune_id).owner != owner: return False del self.db[fine_tune_id] return True + async def get_fine_tune_model_name_from_id( + self, owner: str, fine_tune_id: str + ) -> Optional[str]: + fine_tune = self.db.get(fine_tune_id, None) + if fine_tune is not None and fine_tune.owner == owner: + return fine_tune.annotations["fine_tuned_model"] + return None + class FakeStreamingModelEndpointInferenceGateway(StreamingModelEndpointInferenceGateway): def __init__(self): @@ -1178,7 +1523,11 @@ def __init__(self): ] async def streaming_predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, + topic: str, + predict_request: EndpointPredictV1Request, + manually_resolve_dns: bool = False, + endpoint_name: Optional[str] = None, ) -> AsyncIterable[SyncEndpointPredictV1Response]: """ Runs a prediction request and returns a response. @@ -1188,15 +1537,22 @@ async def streaming_predict( class FakeSyncModelEndpointInferenceGateway(SyncModelEndpointInferenceGateway): - def __init__(self): - self.response = SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result=None, - traceback=None, - ) + def __init__(self, fake_sync_inference_content=None): + if not fake_sync_inference_content: + self.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result=None, + traceback=None, + ) + else: + self.response = fake_sync_inference_content async def predict( - self, topic: str, predict_request: EndpointPredictV1Request + self, + topic: str, + predict_request: EndpointPredictV1Request, + manually_resolve_dns: bool = False, + endpoint_name: Optional[str] = None, ) -> SyncEndpointPredictV1Response: """ Runs a prediction request and returns a response. @@ -1204,6 +1560,52 @@ async def predict( return self.response +class FakeFileStorageGateway(FileStorageGateway): + def __init__(self, contents=None): + self.db: Dict[str, FileMetadata] = {} if contents is None else contents + self.id = 0 + self.content = "Test content" + + async def get_url_from_id(self, owner: str, file_id: str) -> Optional[str]: + return "dummy URL" + + async def upload_file(self, owner: str, filename: str, content: bytes) -> str: + file_id = f"file-{self.id}" + self.id += 1 + + self.db[file_id] = FileMetadata( + id=file_id, + filename=f"{file_id}_name", + size=len(self.content), + owner=owner, + updated_at=datetime.now(), + ) + + return file_id + + async def get_file(self, owner: str, file_id: str) -> Optional[FileMetadata]: + file = self.db.get(file_id) + if file is None or file.owner != owner: + return None + return file + + async def list_files(self, owner: str) -> List[FileMetadata]: + return [file for file in self.db.values() if file.owner == owner] + + async def delete_file(self, owner: str, file_id: str) -> bool: + if file_id not in self.db or self.db.get(file_id).owner != owner: + return False + + del self.db[file_id] + return True + + async def get_file_content(self, owner: str, file_id: str) -> Optional[str]: + file = self.db.get(file_id) + if file is None or file.owner != owner: + return None + return self.content + + @dataclass class FakeAsyncTask: topic: str @@ -1250,6 +1652,25 @@ def get_last_request(self): return self.tasks[-1] +class FakeInferenceAutoscalingMetricsGateway(InferenceAutoscalingMetricsGateway): + async def emit_inference_autoscaling_metric(self, endpoint_id: str): + pass + + async def emit_prewarm_metric(self, endpoint_id: str): + pass + + async def create_or_update_resources(self, endpoint_id: str): + pass + + async def delete_resources(self, endpoint_id: str): + pass + + +class FakeStreamingStorageGateway(StreamingStorageGateway): + def put_record(self, stream_name: str, record: Dict[str, Any]): + pass + + class FakeModelEndpointService(ModelEndpointService): db: Dict[str, ModelEndpoint] @@ -1262,6 +1683,8 @@ def __init__( StreamingModelEndpointInferenceGateway ] = None, sync_model_endpoint_inference_gateway: Optional[SyncModelEndpointInferenceGateway] = None, + inference_autoscaling_metrics_gateway: Optional[InferenceAutoscalingMetricsGateway] = None, + can_scale_http_endpoint_from_zero_flag: bool = True, ): if contents: self.db = contents @@ -1291,10 +1714,17 @@ def __init__( if sync_model_endpoint_inference_gateway is None: sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway() self.sync_model_endpoint_inference_gateway = sync_model_endpoint_inference_gateway + + if inference_autoscaling_metrics_gateway is None: + inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() + self.inference_autoscaling_metrics_gateway = inference_autoscaling_metrics_gateway + self.model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway() ) + self.can_scale_http_endpoint_from_zero_flag = can_scale_http_endpoint_from_zero_flag + def get_async_model_endpoint_inference_gateway( self, ) -> AsyncModelEndpointInferenceGateway: @@ -1310,6 +1740,11 @@ def get_sync_model_endpoint_inference_gateway( ) -> SyncModelEndpointInferenceGateway: return self.sync_model_endpoint_inference_gateway + def get_inference_autoscaling_metrics_gateway( + self, + ) -> InferenceAutoscalingMetricsGateway: + return self.inference_autoscaling_metrics_gateway + def add_model_endpoint(self, model_endpoint: ModelEndpoint): self.db[model_endpoint.record.id] = model_endpoint @@ -1327,7 +1762,8 @@ async def create_model_endpoint( gpus: int, memory: StorageSpecificationType, gpu_type: Optional[GpuType], - storage: Optional[StorageSpecificationType], + storage: StorageSpecificationType, + nodes_per_worker: int, optimize_costs: bool, min_workers: int, max_workers: int, @@ -1337,6 +1773,7 @@ async def create_model_endpoint( results_s3_bucket: str, prewarm: Optional[bool], high_priority: Optional[bool], + billing_tags: Optional[Dict[str, Any]] = None, owner: str, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, @@ -1385,6 +1822,7 @@ async def create_model_endpoint( memory=memory, gpu_type=gpu_type, storage=storage, + nodes_per_worker=nodes_per_worker, optimize_costs=optimize_costs, ), user_config_state=ModelEndpointUserConfigState( @@ -1393,7 +1831,9 @@ async def create_model_endpoint( bundle_name=current_model_bundle.name, endpoint_name=name, post_inference_hooks=post_inference_hooks, + billing_tags=billing_tags, user_id=created_by, + billing_queue="some:arn:for:something", default_callback_url=default_callback_url, default_callback_auth=default_callback_auth, ), @@ -1424,6 +1864,7 @@ async def update_model_endpoint( results_s3_bucket: Optional[str] = None, prewarm: Optional[bool] = None, high_priority: Optional[bool] = None, + billing_tags: Optional[Dict[str, Any]] = None, default_callback_url: Optional[str] = None, default_callback_auth: Optional[CallbackAuth] = None, public_inference: Optional[bool] = None, @@ -1501,6 +1942,17 @@ async def delete_model_endpoint(self, model_endpoint_id: str) -> None: raise ObjectNotFoundException del self.db[model_endpoint_id] + def set_can_scale_http_endpoint_from_zero_flag(self, flag: bool): + self.can_scale_http_endpoint_from_zero_flag = flag + + def can_scale_http_endpoint_from_zero(self) -> bool: + return self.can_scale_http_endpoint_from_zero_flag + + +class FakeTokenizerRepository(TokenizerRepository): + def load_tokenizer(self, model_name: str) -> AutoTokenizer: + return AutoTokenizer.from_pretrained(model_name) + class FakeLLMModelEndpointService(LLMModelEndpointService): db: Dict[str, ModelEndpoint] @@ -1621,6 +2073,23 @@ def fake_docker_image_batch_job_bundle_repository() -> FakeDockerImageBatchJobBu return repo +@pytest.fixture +def fake_llm_fine_tune_repository() -> FakeLLMFineTuneRepository: + repo = FakeLLMFineTuneRepository() + return repo + + +@pytest.fixture +def fake_llm_fine_tuning_events_repository() -> FakeLLMFineTuneEventsRepository: + repo = FakeLLMFineTuneEventsRepository() + return repo + + +def fake_trigger_repository() -> FakeTriggerRepository: + repo = FakeTriggerRepository() + return repo + + @pytest.fixture def fake_image_cache_gateway() -> FakeImageCacheGateway: gateway = FakeImageCacheGateway() @@ -1639,6 +2108,18 @@ def fake_batch_job_orchestration_gateway() -> FakeBatchJobOrchestrationGateway: return gateway +@pytest.fixture +def fake_docker_image_batch_job_gateway() -> FakeDockerImageBatchJobGateway: + gateway = FakeDockerImageBatchJobGateway() + return gateway + + +@pytest.fixture +def fake_llm_batch_completions_service() -> FakeLLMBatchCompletionsService: + service = FakeLLMBatchCompletionsService() + return service + + @pytest.fixture def fake_monitoring_metrics_gateway() -> FakeMonitoringMetricsGateway: gateway = FakeMonitoringMetricsGateway() @@ -1693,6 +2174,29 @@ def fake_sync_model_endpoint_inference_gateway() -> FakeSyncModelEndpointInferen return gateway +@pytest.fixture +def fake_inference_autoscaling_metrics_gateway() -> FakeInferenceAutoscalingMetricsGateway: + gateway = FakeInferenceAutoscalingMetricsGateway() + return gateway + + +@pytest.fixture +def fake_file_storage_gateway() -> FakeFileStorageGateway: + gateway = FakeFileStorageGateway() + return gateway + + +@pytest.fixture +def fake_llm_artifact_gateway() -> FakeLLMArtifactGateway: + gateway = FakeLLMArtifactGateway() + return gateway + + +def fake_cron_job_gateway() -> FakeCronJobGateway: + gateway = FakeCronJobGateway() + return gateway + + @pytest.fixture def fake_model_endpoint_service() -> FakeModelEndpointService: service = FakeModelEndpointService() @@ -1705,6 +2209,12 @@ def fake_llm_model_endpoint_service() -> FakeLLMModelEndpointService: return service +@pytest.fixture +def fake_llm_fine_tuning_service() -> FakeLLMFineTuningService: + service = FakeLLMFineTuningService() + return service + + @pytest.fixture def fake_image_cache_service( fake_image_cache_gateway, @@ -1718,6 +2228,17 @@ def fake_image_cache_service( ) +@pytest.fixture +def fake_tokenizer_repository() -> TokenizerRepository: + return FakeTokenizerRepository() + + +@pytest.fixture +def fake_streaming_storage_gateway() -> StreamingStorageGateway: + gateway = FakeStreamingStorageGateway() + return gateway + + @pytest.fixture def get_repositories_generator_wrapper(): def get_repositories_generator( @@ -1727,14 +2248,21 @@ def get_repositories_generator( fake_model_endpoint_infra_gateway_contents, fake_batch_job_record_repository_contents, fake_batch_job_progress_gateway_contents, + fake_cron_job_gateway_contents, fake_docker_image_batch_job_bundle_repository_contents, fake_docker_image_batch_job_gateway_contents, fake_llm_fine_tuning_service_contents, + fake_file_storage_gateway_contents, + fake_trigger_repository_contents, + fake_file_system_gateway_contents, + fake_sync_inference_content, ): def get_test_repositories() -> Iterator[ExternalInterfaces]: + fake_file_system_gateway = FakeFilesystemGateway() fake_model_bundle_repository = FakeModelBundleRepository( contents=fake_model_bundle_repository_contents ) + fake_monitoring_metrics_gateway = FakeMonitoringMetricsGateway() fake_model_endpoint_record_repository = FakeModelEndpointRecordRepository( contents=fake_model_endpoint_record_repository_contents, model_bundle_repository=fake_model_bundle_repository, @@ -1748,7 +2276,10 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: streaming_model_endpoint_inference_gateway = ( FakeStreamingModelEndpointInferenceGateway() ) - sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway() + sync_model_endpoint_inference_gateway = FakeSyncModelEndpointInferenceGateway( + fake_sync_inference_content + ) + inference_autoscaling_metrics_gateway = FakeInferenceAutoscalingMetricsGateway() model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( filesystem_gateway=FakeFilesystemGateway(), ) @@ -1759,7 +2290,9 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: async_model_endpoint_inference_gateway=async_model_endpoint_inference_gateway, streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway, sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, + inference_autoscaling_metrics_gateway=inference_autoscaling_metrics_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, + can_scale_http_endpoint_from_zero_flag=True, # reasonable default, gets overridden in individual tests if needed ) fake_batch_job_service = LiveBatchJobService( batch_job_record_repository=FakeBatchJobRecordRepository( @@ -1777,16 +2310,29 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: fake_docker_image_batch_job_bundle_repository = FakeDockerImageBatchJobBundleRepository( contents=fake_docker_image_batch_job_bundle_repository_contents ) + fake_trigger_repository = FakeTriggerRepository( + contents=fake_trigger_repository_contents + ) fake_docker_image_batch_job_gateway = FakeDockerImageBatchJobGateway( fake_docker_image_batch_job_gateway_contents ) + fake_llm_artifact_gateway = FakeLLMArtifactGateway() + fake_cron_job_gateway = FakeCronJobGateway(fake_cron_job_gateway_contents) fake_llm_model_endpoint_service = LiveLLMModelEndpointService( model_endpoint_record_repository=fake_model_endpoint_record_repository, model_endpoint_service=fake_model_endpoint_service, ) + fake_llm_batch_completions_service = LiveLLMBatchCompletionsService( + docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway + ) fake_llm_fine_tuning_service = FakeLLMFineTuningService( fake_llm_fine_tuning_service_contents ) + fake_llm_fine_tuning_events_repository = FakeLLMFineTuneEventsRepository() + fake_file_storage_gateway = FakeFileStorageGateway(fake_file_storage_gateway_contents) + fake_tokenizer_repository = FakeTokenizerRepository() + fake_streaming_storage_gateway = FakeStreamingStorageGateway() + repositories = ExternalInterfaces( docker_repository=FakeDockerRepository( fake_docker_repository_image_always_exists, False @@ -1794,6 +2340,7 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: model_bundle_repository=fake_model_bundle_repository, model_endpoint_service=fake_model_endpoint_service, llm_model_endpoint_service=fake_llm_model_endpoint_service, + llm_batch_completions_service=fake_llm_batch_completions_service, batch_job_service=fake_batch_job_service, resource_gateway=FakeEndpointResourceGateway(), endpoint_creation_task_queue_gateway=FakeTaskQueueGateway(), @@ -1803,6 +2350,15 @@ def get_test_repositories() -> Iterator[ExternalInterfaces]: docker_image_batch_job_bundle_repository=fake_docker_image_batch_job_bundle_repository, docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway, llm_fine_tuning_service=fake_llm_fine_tuning_service, + llm_fine_tune_events_repository=fake_llm_fine_tuning_events_repository, + file_storage_gateway=fake_file_storage_gateway, + trigger_repository=fake_trigger_repository, + cron_job_gateway=fake_cron_job_gateway, + filesystem_gateway=fake_file_system_gateway, + llm_artifact_gateway=fake_llm_artifact_gateway, + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, + tokenizer_repository=fake_tokenizer_repository, + streaming_storage_gateway=fake_streaming_storage_gateway, ) try: yield repositories @@ -1985,7 +2541,7 @@ def model_bundle_4(test_api_key: str) -> ModelBundle: ecr_repo="test_repo", image_tag="test_tag", ), - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + packaging_type=ModelBundlePackagingType.LIRA, app_config=None, ) return model_bundle @@ -2022,7 +2578,7 @@ def model_bundle_5(test_api_key: str) -> ModelBundle: ecr_repo="test_repo", image_tag="test_tag", ), - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + packaging_type=ModelBundlePackagingType.LIRA, app_config=None, ) return model_bundle @@ -2063,7 +2619,7 @@ def model_bundle_6(test_api_key: str) -> ModelBundle: ecr_repo="test_repo", image_tag="test_tag", ), - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + packaging_type=ModelBundlePackagingType.LIRA, app_config=None, ) return model_bundle @@ -2108,7 +2664,7 @@ def model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage( ecr_repo="test_repo", image_tag="test_tag", ), - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + packaging_type=ModelBundlePackagingType.LIRA, app_config=None, ) return model_bundle @@ -2152,6 +2708,7 @@ def model_endpoint_1(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -2160,6 +2717,16 @@ def model_endpoint_1(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd bundle_name=model_bundle_1.name, endpoint_name="test_model_endpoint_name_1", post_inference_hooks=None, + billing_tags={ + "idempotencyKeyPrefix": "value1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, ), ), image="000000000000.dkr.ecr.us-west-2.amazonaws.com/non-existent-repo:fake-tag", @@ -2207,6 +2774,7 @@ def model_endpoint_2(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -2217,7 +2785,7 @@ def model_endpoint_2(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd post_inference_hooks=None, ), ), - image="000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:40d3b5fb06d1a8c3d14903390a3b23ae388bdb19", + image="000000000000.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg222", ), ) return model_endpoint @@ -2262,6 +2830,7 @@ def model_endpoint_3(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -2272,7 +2841,7 @@ def model_endpoint_3(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd post_inference_hooks=None, ), ), - image="000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:e4ea48ddccfb9ca3ef6d846ae9b2d146d7e30b0f", + image="000000000000.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg111111111", ), ) return model_endpoint @@ -2317,6 +2886,7 @@ def model_endpoint_4(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -2327,7 +2897,7 @@ def model_endpoint_4(test_api_key: str, model_bundle_1: ModelBundle) -> ModelEnd post_inference_hooks=None, ), ), - image="000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:9a319cd9b897f02291f3242b1395f2b669993cdf-fd", + image="000000000000.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg00000", ), ) return model_endpoint @@ -2372,6 +2942,7 @@ def model_endpoint_public(test_api_key: str, model_bundle_1: ModelBundle) -> Mod memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -2380,6 +2951,16 @@ def model_endpoint_public(test_api_key: str, model_bundle_1: ModelBundle) -> Mod bundle_name=model_bundle_1.name, endpoint_name="test_model_endpoint_name_1", post_inference_hooks=None, + billing_tags={ + "idempotencyKeyPrefix": "value1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, ), ), image="000000000000.dkr.ecr.us-west-2.amazonaws.com/non-existent-repo:fake-tag", @@ -2427,6 +3008,7 @@ def model_endpoint_public_sync(test_api_key: str, model_bundle_1: ModelBundle) - memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -2435,6 +3017,16 @@ def model_endpoint_public_sync(test_api_key: str, model_bundle_1: ModelBundle) - bundle_name=model_bundle_1.name, endpoint_name="test_model_endpoint_name_1", post_inference_hooks=None, + billing_tags={ + "idempotencyKeyPrefix": "value1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, ), ), image="000000000000.dkr.ecr.us-west-2.amazonaws.com/non-existent-repo:fake-tag", @@ -2483,6 +3075,7 @@ def model_endpoint_runnable(test_api_key: str, model_bundle_4: ModelBundle) -> M memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -2539,6 +3132,7 @@ def model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) -> memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -2705,11 +3299,12 @@ def build_endpoint_request_async_runnable_image( memory="3G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -2748,11 +3343,12 @@ def build_endpoint_request_streaming_runnable_image( memory="4G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -2791,11 +3387,12 @@ def build_endpoint_request_sync_runnable_image( memory="4G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -2834,11 +3431,12 @@ def build_endpoint_request_sync_pytorch( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, broker_type=BrokerType.SQS, default_callback_url="https://example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -2877,10 +3475,11 @@ def build_endpoint_request_async_tensorflow( memory="1G", gpu_type=None, storage=None, + nodes_per_worker=1, optimize_costs=False, default_callback_url="https://example.com/path", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth(kind="basic", username="username", password="password") + root=CallbackBasicAuth(kind="basic", username="username", password="password") ), ) return build_endpoint_request @@ -2919,6 +3518,50 @@ def build_endpoint_request_async_custom( memory="1G", gpu_type=None, storage=None, + nodes_per_worker=1, + optimize_costs=True, + broker_type=BrokerType.SQS, + default_callback_url=None, + default_callback_auth=None, + ) + return build_endpoint_request + + +@pytest.fixture +def build_endpoint_request_async_zipartifact_highpri( + test_api_key: str, model_bundle_3: ModelBundle +) -> BuildEndpointRequest: + build_endpoint_request = BuildEndpointRequest( + model_endpoint_record=ModelEndpointRecord( + id="test_model_endpoint_id_3", + name="test_model_endpoint_name_3", + created_by=test_api_key, + created_at=datetime(2022, 1, 4), + last_updated_at=datetime(2022, 1, 4), + metadata={}, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.ASYNC, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_3, + owner=test_api_key, + ), + high_priority=True, + deployment_name=f"{test_api_key}-test_model_endpoint_name_3", + aws_role="default", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + post_inference_hooks=None, + labels=dict(team="test_team", product="test_product"), + min_workers=1, + max_workers=3, + per_worker=2, + cpus=1, + gpus=0, + memory="1G", + gpu_type=None, + storage=None, + nodes_per_worker=1, optimize_costs=True, broker_type=BrokerType.SQS, default_callback_url=None, @@ -2960,6 +3603,7 @@ def build_endpoint_request_sync_custom( memory="1G", gpu_type=None, storage=None, + nodes_per_worker=1, optimize_costs=True, default_callback_url=None, default_callback_auth=None, @@ -2983,9 +3627,7 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An args=["test_arg_1", "test_arg_2"], callback_url="http://test_callback_url.xyz", callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( - kind="basic", username="test_username", password="test_password" - ) + root=CallbackBasicAuth(kind="basic", username="test_username", password="test_password") ), return_pickled=True, ) @@ -2994,12 +3636,426 @@ def endpoint_predict_request_2() -> Tuple[EndpointPredictV1Request, Dict[str, An @pytest.fixture -def llm_model_endpoint_async( - test_api_key: str, model_bundle_1: ModelBundle -) -> Tuple[ModelEndpoint, Any]: - model_endpoint = ModelEndpoint( - record=ModelEndpointRecord( - id="test_llm_model_endpoint_id_1", +def sync_endpoint_predict_request_1() -> Tuple[SyncEndpointPredictV1Request, Dict[str, Any]]: + request = SyncEndpointPredictV1Request( + url="test_url", + return_pickled=False, + timeout_seconds=10, + num_retries=5, + ) + request_dict = request.dict() + return request, request_dict + + +@pytest.fixture +def llm_model_endpoint_async( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_1", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "deepspeed", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.ASYNC, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + root=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_1", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "deepspeed", + "inference_framework_image_tag": "123", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_1", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "async", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "deepspeed", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": 1, + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "nodes_per_worker": 1, + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + +@pytest.fixture +def llm_model_endpoint_sync( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.SYNC, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + root=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "sync", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": 1, + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "nodes_per_worker": 1, + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + +@pytest.fixture +def llm_model_endpoint_stream( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.STREAMING, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + root=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "streaming", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "vllm", + "inference_framework_image_tag": "123", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": 1, + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "nodes_per_worker": 1, + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + +@pytest.fixture +def llm_model_endpoint_sync_tgi( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", name="test_llm_model_endpoint_name_1", created_by=test_api_key, created_at=datetime(2022, 1, 3), @@ -3008,13 +4064,13 @@ def llm_model_endpoint_async( "_llm": { "model_name": "llama-7b", "source": "hugging_face", - "inference_framework": "deepspeed", - "inference_framework_image_tag": "123", + "inference_framework": "text_generation_inference", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, } }, creation_task_id="test_creation_task_id", - endpoint_type=ModelEndpointType.ASYNC, + endpoint_type=ModelEndpointType.SYNC, destination="test_destination", status=ModelEndpointStatus.READY, current_model_bundle=model_bundle_1, @@ -3042,6 +4098,7 @@ def llm_model_endpoint_async( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -3052,7 +4109,7 @@ def llm_model_endpoint_async( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -3065,25 +4122,26 @@ def llm_model_endpoint_async( ), ) model_endpoint_json: Dict[str, Any] = { - "id": "test_llm_model_endpoint_id_1", + "id": "test_llm_model_endpoint_id_2", "name": "test_llm_model_endpoint_name_1", "model_name": "llama-7b", "source": "hugging_face", - "inference_framework": "deepspeed", - "inference_framework_image_tag": "123", + "status": "READY", + "inference_framework": "text_generation_inference", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, "spec": { - "id": "test_llm_model_endpoint_id_1", + "id": "test_llm_model_endpoint_id_2", "name": "test_llm_model_endpoint_name_1", - "endpoint_type": "async", + "endpoint_type": "sync", "destination": "test_destination", "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", "metadata": { "_llm": { "model_name": "llama-7b", "source": "hugging_face", - "inference_framework": "deepspeed", - "inference_framework_image_tag": "123", + "inference_framework": "text_generation_inference", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, } }, @@ -3110,11 +4168,12 @@ def llm_model_endpoint_async( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -3125,7 +4184,7 @@ def llm_model_endpoint_async( @pytest.fixture -def llm_model_endpoint_sync( +def llm_model_endpoint_sync_lightllm( test_api_key: str, model_bundle_1: ModelBundle ) -> Tuple[ModelEndpoint, Any]: model_endpoint = ModelEndpoint( @@ -3139,8 +4198,8 @@ def llm_model_endpoint_sync( "_llm": { "model_name": "llama-7b", "source": "hugging_face", - "inference_framework": "deepspeed", - "inference_framework_image_tag": "123", + "inference_framework": "lightllm", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, } }, @@ -3173,6 +4232,7 @@ def llm_model_endpoint_sync( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -3183,7 +4243,7 @@ def llm_model_endpoint_sync( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -3200,8 +4260,9 @@ def llm_model_endpoint_sync( "name": "test_llm_model_endpoint_name_1", "model_name": "llama-7b", "source": "hugging_face", - "inference_framework": "deepspeed", - "inference_framework_image_tag": "123", + "status": "READY", + "inference_framework": "lightllm", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, "spec": { "id": "test_llm_model_endpoint_id_2", @@ -3213,8 +4274,142 @@ def llm_model_endpoint_sync( "_llm": { "model_name": "llama-7b", "source": "hugging_face", - "inference_framework": "deepspeed", - "inference_framework_image_tag": "123", + "inference_framework": "lightllm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + } + }, + "bundle_name": "test_model_bundle_name_1", + "status": "READY", + "post_inference_hooks": ["callback"], + "default_callback_url": "http://www.example.com", + "default_callback_auth": { + "kind": "basic", + "username": "test_username", + "password": "test_password", + }, + "labels": {}, + "aws_role": "test_aws_role", + "results_s3_bucket": "test_s3_bucket", + "created_by": test_api_key, + "created_at": "2022-01-03T00:00:00", + "last_updated_at": "2022-01-03T00:00:00", + "deployment_state": { + "min_workers": 1, + "max_workers": 3, + "per_worker": 2, + "available_workers": 1, + "unavailable_workers": 1, + }, + "resource_state": { + "cpus": 1, + "gpus": 1, + "memory": "1G", + "gpu_type": "nvidia-tesla-t4", + "storage": "10G", + "nodes_per_worker": 1, + "optimize_costs": True, + }, + "num_queued_items": 1, + "public_inference": True, + }, + } + return model_endpoint, model_endpoint_json + + +@pytest.fixture +def llm_model_endpoint_sync_trt_llm( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + model_endpoint = ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_2", + name="test_llm_model_endpoint_name_1", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.SYNC, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_1", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + root=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + model_endpoint_json: Dict[str, Any] = { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "model_name": "llama-7b", + "source": "hugging_face", + "status": "READY", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "0.9.4", + "num_shards": 4, + "spec": { + "id": "test_llm_model_endpoint_id_2", + "name": "test_llm_model_endpoint_name_1", + "endpoint_type": "sync", + "destination": "test_destination", + "deployment_name": f"{test_api_key}-test_llm_model_endpoint_name_1", + "metadata": { + "_llm": { + "model_name": "llama-7b", + "source": "hugging_face", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, } }, @@ -3241,11 +4436,12 @@ def llm_model_endpoint_sync( "unavailable_workers": 1, }, "resource_state": { - "cpus": "1", + "cpus": 1, "gpus": 1, "memory": "1G", "gpu_type": "nvidia-tesla-t4", "storage": "10G", + "nodes_per_worker": 1, "optimize_costs": True, }, "num_queued_items": 1, @@ -3304,6 +4500,7 @@ def llm_model_endpoint_streaming(test_api_key: str, model_bundle_5: ModelBundle) memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( @@ -3336,7 +4533,7 @@ def llm_model_endpoint_text_generation_inference( "model_name": "llama-7b", "source": "hugging_face", "inference_framework": "text_generation_inference", - "inference_framework_image_tag": "123", + "inference_framework_image_tag": "0.9.4", "num_shards": 4, } }, @@ -3369,6 +4566,81 @@ def llm_model_endpoint_text_generation_inference( memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, + optimize_costs=True, + ), + user_config_state=ModelEndpointUserConfigState( + app_config=model_bundle_1.app_config, + endpoint_config=ModelEndpointConfig( + bundle_name=model_bundle_1.name, + endpoint_name="test_llm_model_endpoint_name_1", + post_inference_hooks=["callback"], + default_callback_url="http://www.example.com", + default_callback_auth=CallbackAuth( + root=CallbackBasicAuth( + kind="basic", + username="test_username", + password="test_password", + ), + ), + ), + ), + num_queued_items=1, + image="test_image", + ), + ) + + +@pytest.fixture +def llm_model_endpoint_trt_llm( + test_api_key: str, model_bundle_1: ModelBundle +) -> Tuple[ModelEndpoint, Any]: + return ModelEndpoint( + record=ModelEndpointRecord( + id="test_llm_model_endpoint_id_3", + name="test_llm_model_endpoint_name_trt_llm", + created_by=test_api_key, + created_at=datetime(2022, 1, 3), + last_updated_at=datetime(2022, 1, 3), + metadata={ + "_llm": { + "model_name": "llama-2-7b", + "source": "hugging_face", + "inference_framework": "tensorrt_llm", + "inference_framework_image_tag": "23.10", + "num_shards": 4, + } + }, + creation_task_id="test_creation_task_id", + endpoint_type=ModelEndpointType.STREAMING, + destination="test_destination", + status=ModelEndpointStatus.READY, + current_model_bundle=model_bundle_1, + owner=test_api_key, + public_inference=True, + ), + infra_state=ModelEndpointInfraState( + deployment_name=f"{test_api_key}-test_llm_model_endpoint_name_trt_llm", + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + child_fn_info=None, + labels={}, + prewarm=True, + high_priority=False, + deployment_state=ModelEndpointDeploymentState( + min_workers=1, + max_workers=3, + per_worker=2, + available_workers=1, + unavailable_workers=1, + ), + resource_state=ModelEndpointResourceState( + cpus=1, + gpus=1, + memory="1G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, optimize_costs=True, ), user_config_state=ModelEndpointUserConfigState( @@ -3379,7 +4651,7 @@ def llm_model_endpoint_text_generation_inference( post_inference_hooks=["callback"], default_callback_url="http://www.example.com", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( + root=CallbackBasicAuth( kind="basic", username="test_username", password="test_password", @@ -3391,3 +4663,85 @@ def llm_model_endpoint_text_generation_inference( image="test_image", ), ) + + +def mocked__get_recommended_hardware_config_map(): + async def async_mock(*args, **kwargs): # noqa + return { + "byGpuMemoryGb": """ + - gpu_memory_le: 20 + cpus: 5 + gpus: 1 + memory: 20Gi + storage: 40Gi + gpu_type: nvidia-hopper-h100-1g20gb + nodes_per_worker: 1 + - gpu_memory_le: 40 + cpus: 10 + gpus: 1 + memory: 40Gi + storage: 80Gi + gpu_type: nvidia-hopper-h100-3g40gb + nodes_per_worker: 1 + - gpu_memory_le: 80 + cpus: 20 + gpus: 1 + memory: 80Gi + storage: 96Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 160 + cpus: 40 + gpus: 2 + memory: 160Gi + storage: 160Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 320 + cpus: 80 + gpus: 4 + memory: 320Gi + storage: 320Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 640 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - gpu_memory_le: 1280 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 900Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 2 + """, + "byModelName": """ + - name: llama-3-8b-instruct-262k + cpus: 40 + gpus: 2 + memory: 160Gi + storage: 160Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - name: deepseek-coder-v2 + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + - name: deepseek-coder-v2-instruct + cpus: 160 + gpus: 8 + memory: 800Gi + storage: 640Gi + gpu_type: nvidia-hopper-h100 + nodes_per_worker: 1 + """, + } + + return mock.AsyncMock(side_effect=async_mock) diff --git a/model-engine/tests/unit/core/utils/test_timer.py b/model-engine/tests/unit/core/utils/test_timer.py new file mode 100644 index 00000000..f5d3b2d1 --- /dev/null +++ b/model-engine/tests/unit/core/utils/test_timer.py @@ -0,0 +1,15 @@ +import time + +from model_engine_server.core.utils.timer import timer + + +def test_timer(): + with timer() as t: + time.sleep(0.1) + lap_time = t.lap() + time.sleep(0.01) + new_lap_time = t.lap() + + assert new_lap_time >= 0.009 + assert lap_time >= 0.09 + assert t.duration >= 0.1 diff --git a/model-engine/tests/unit/domain/__init__.py b/model-engine/tests/unit/domain/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/model-engine/tests/unit/domain/conftest.py b/model-engine/tests/unit/domain/conftest.py new file mode 100644 index 00000000..f18808a7 --- /dev/null +++ b/model-engine/tests/unit/domain/conftest.py @@ -0,0 +1,696 @@ +import pytest +from model_engine_server.common.dtos.batch_jobs import ( + CreateDockerImageBatchJobBundleV1Request, + CreateDockerImageBatchJobResourceRequests, +) +from model_engine_server.common.dtos.llms import ( + CompletionStreamV1Request, + CompletionSyncV1Request, + CreateBatchCompletionsV1ModelConfig, + CreateBatchCompletionsV1Request, + CreateBatchCompletionsV1RequestContent, + CreateLLMModelEndpointV1Request, + UpdateLLMModelEndpointV1Request, +) +from model_engine_server.common.dtos.llms.batch_completion import ( + CreateBatchCompletionsV2ModelConfig, + CreateBatchCompletionsV2Request, + FilteredCompletionV2Request, +) +from model_engine_server.common.dtos.model_bundles import ( + CreateModelBundleV1Request, + CreateModelBundleV2Request, +) +from model_engine_server.common.dtos.model_endpoints import ( + CreateModelEndpointV1Request, + UpdateModelEndpointV1Request, +) +from model_engine_server.domain.entities import ( + GpuType, + LLMInferenceFramework, + ModelBundle, + ModelBundleEnvironmentParams, + ModelBundleFrameworkType, + ModelBundlePackagingType, + ModelEndpointType, + Quantization, + StreamingEnhancedRunnableImageFlavor, +) +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( + CONVERTED_FROM_ARTIFACT_LIKE_KEY, +) + + +@pytest.fixture +def create_model_bundle_request() -> CreateModelBundleV1Request: + env_params = ModelBundleEnvironmentParams( + framework_type=ModelBundleFrameworkType.CUSTOM, + ecr_repo="test_repo", + image_tag="test_tag", + ) + return CreateModelBundleV1Request( + name="test_bundle_name", + location="test_location", + requirements=["numpy==0.0.0"], + env_params=env_params, + packaging_type=ModelBundlePackagingType.CLOUDPICKLE, + metadata=None, + app_config=None, + ) + + +@pytest.fixture +def create_model_bundle_v2_request() -> CreateModelBundleV2Request: + return CreateModelBundleV2Request( + name="test_bundle_name", + metadata=None, + schema_location="s3://test-bucket/test-key", + flavor=StreamingEnhancedRunnableImageFlavor( + flavor="streaming_enhanced_runnable_image", + repository="test_repo", + tag="test_tag", + command=["test_command"], + env={"test_key": "test_value"}, + protocol="http", + readiness_initial_delay_seconds=30, + streaming_command=["test_streaming_command"], + streaming_predict_route="/test_streaming_predict_route", + ), + ) + + +@pytest.fixture +def create_model_endpoint_request_sync( + model_bundle_1: ModelBundle, +) -> CreateModelEndpointV1Request: + return CreateModelEndpointV1Request( + name="test_endpoint_name_1", + model_bundle_id=model_bundle_1.id, + endpoint_type=ModelEndpointType.SYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=1, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + ) + + +@pytest.fixture +def create_model_endpoint_request_streaming( + model_bundle_5: ModelBundle, +) -> CreateModelEndpointV1Request: + return CreateModelEndpointV1Request( + name="test_endpoint_name_2", + model_bundle_id=model_bundle_5.id, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=1, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=1, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + ) + + +@pytest.fixture +def create_model_endpoint_request_async( + model_bundle_1: ModelBundle, +) -> CreateModelEndpointV1Request: + return CreateModelEndpointV1Request( + name="test_endpoint_name_2", + model_bundle_id=model_bundle_1.id, + endpoint_type=ModelEndpointType.ASYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=1, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + ) + + +@pytest.fixture +def update_model_endpoint_request( + model_bundle_2: ModelBundle, +) -> UpdateModelEndpointV1Request: + return UpdateModelEndpointV1Request( + model_bundle_id=model_bundle_2.id, + metadata={"test_new_key": "test_new_value"}, + cpus=2, + memory="16G", + min_workers=1, + max_workers=4, + per_worker=2, + ) + + +@pytest.fixture +def create_docker_image_batch_job_bundle_request() -> CreateDockerImageBatchJobBundleV1Request: + return CreateDockerImageBatchJobBundleV1Request( + name="name", + image_repository="repo", + image_tag="tag", + command=["sudo", "rn", "-rf"], + env=dict(hi="hi", bye="bye"), + mount_location=None, + resource_requests=CreateDockerImageBatchJobResourceRequests( + cpus=1, memory=None, storage=None, gpus=None, gpu_type=None + ), + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_sync() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_sync", + model_name="mpt-7b", + source="hugging_face", + inference_framework="deepspeed", + inference_framework_image_tag="test_tag", + num_shards=2, + endpoint_type=ModelEndpointType.SYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://mpt-7b", + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_async", + model_name="mpt-7b", + source="hugging_face", + inference_framework="deepspeed", + inference_framework_image_tag="latest", + num_shards=2, + endpoint_type=ModelEndpointType.ASYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=0, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-2-7b", + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_streaming", + model_name="mpt-7b", + source="hugging_face", + inference_framework="deepspeed", + inference_framework_image_tag="test_tag", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://mpt-7b", + ) + + +@pytest.fixture +def update_llm_model_endpoint_request() -> UpdateLLMModelEndpointV1Request: + return UpdateLLMModelEndpointV1Request( + inference_framework_image_tag="latest", + checkpoint_path="s3://mpt-7b", + memory="4G", + min_workers=0, + max_workers=1, + ) + + +@pytest.fixture +def update_llm_model_endpoint_request_only_workers() -> UpdateLLMModelEndpointV1Request: + return UpdateLLMModelEndpointV1Request( + min_workers=5, + max_workers=10, + ) + + +@pytest.fixture +def update_llm_model_endpoint_request_bad_metadata() -> UpdateLLMModelEndpointV1Request: + return UpdateLLMModelEndpointV1Request(metadata={CONVERTED_FROM_ARTIFACT_LIKE_KEY: {}}) + + +@pytest.fixture +def create_llm_model_endpoint_request_llama_2() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_2", + model_name="llama-2-7b", + source="hugging_face", + inference_framework="text_generation_inference", + inference_framework_image_tag="0.9.4", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-2-7b", + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_llama_3_70b() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_3_70b", + model_name="llama-3-70b", + source="hugging_face", + inference_framework="vllm", + inference_framework_image_tag="1.0.0", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-3-70b", + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_llama_3_70b_chat() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_3_70b_chat", + model_name="llama-3-70b", + source="hugging_face", + inference_framework="vllm", + inference_framework_image_tag="1.0.0", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-3-70b", + chat_template_override="test-template", + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args() -> ( + CreateLLMModelEndpointV1Request +): + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_3_70b_chat", + model_name="llama-3-70b", + source="hugging_face", + inference_framework="vllm", + inference_framework_image_tag="1.0.0", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-3-70b", + chat_template_override="test-template", + max_model_len=1000, + max_num_seqs=10, + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_llama_3_1_405b_instruct() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_llama_3_1_405b_instruct", + model_name="llama-3-1-405b-instruct", + source="hugging_face", + inference_framework="vllm", + inference_framework_image_tag="1.0.0", + num_shards=8, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=8, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + nodes_per_worker=2, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://llama-3-1-405b-instruct", + ) + + +@pytest.fixture +def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( + CreateLLMModelEndpointV1Request +): + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_tgi_streaming", + model_name="mpt-7b", + source="hugging_face", + inference_framework="deepspeed", + inference_framework_image_tag="test_tag", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://mpt-7b", + ) + + +@pytest.fixture +def create_llm_model_endpoint_text_generation_inference_request_async() -> ( + CreateLLMModelEndpointV1Request +): + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_tgi_async", + model_name="mpt-7b", + source="hugging_face", + inference_framework="text_generation_inference", + inference_framework_image_tag="0.9.4", + num_shards=2, + quantize=Quantization.BITSANDBYTES, + endpoint_type=ModelEndpointType.ASYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + ) + + +@pytest.fixture +def create_llm_model_endpoint_trt_llm_request_streaming() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_trt_llm_streaming", + model_name="llama-2-7b", + source="hugging_face", + inference_framework="tensorrt_llm", + inference_framework_image_tag="23.10", + num_shards=2, + endpoint_type=ModelEndpointType.STREAMING, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test_checkpoint_path", + ) + + +@pytest.fixture +def create_llm_model_endpoint_trt_llm_request_async() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_tgi_async", + model_name="llama-2-7b", + source="hugging_face", + inference_framework="tensorrt_llm", + inference_framework_image_tag="23.10", + num_shards=2, + quantize=Quantization.BITSANDBYTES, + endpoint_type=ModelEndpointType.ASYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + checkpoint_path="s3://test_checkpoint_path", + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_1", + model_name="nonexist", + source="hugging_face", + inference_framework="deepspeed", + inference_framework_image_tag="test_tag", + num_shards=2, + endpoint_type=ModelEndpointType.SYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + ) + + +@pytest.fixture +def create_llm_model_endpoint_request_invalid_quantization() -> CreateLLMModelEndpointV1Request: + return CreateLLMModelEndpointV1Request( + name="test_llm_endpoint_name_1", + model_name="nonexist", + source="hugging_face", + inference_framework=LLMInferenceFramework.VLLM, + inference_framework_image_tag="test_tag", + num_shards=2, + quantize=Quantization.BITSANDBYTES, + endpoint_type=ModelEndpointType.SYNC, + metadata={}, + post_inference_hooks=["billing"], + cpus=1, + gpus=2, + memory="8G", + gpu_type=GpuType.NVIDIA_TESLA_T4, + storage="10G", + nodes_per_worker=1, + min_workers=1, + max_workers=3, + per_worker=2, + labels={"team": "infra", "product": "my_product"}, + aws_role="test_aws_role", + results_s3_bucket="test_s3_bucket", + ) + + +@pytest.fixture +def completion_sync_request() -> CompletionSyncV1Request: + return CompletionSyncV1Request( + prompt="What is machine learning?", + max_new_tokens=10, + temperature=0.5, + return_token_log_probs=True, + ) + + +@pytest.fixture +def completion_stream_request() -> CompletionStreamV1Request: + return CompletionStreamV1Request( + prompt="What is machine learning?", + max_new_tokens=10, + temperature=0.5, + ) + + +@pytest.fixture +def create_batch_completions_v1_request() -> CreateBatchCompletionsV1Request: + return CreateBatchCompletionsV1Request( + input_data_path="test_input_data_path", + output_data_path="test_output_data_path", + content=CreateBatchCompletionsV1RequestContent( + prompts=["What is machine learning?"], + max_new_tokens=10, + temperature=0.5, + ), + model_config=CreateBatchCompletionsV1ModelConfig( + model="mpt-7b", + checkpoint_path="s3://test_checkpoint_path", + labels={}, + num_shards=1, + ), + data_parallelism=1, + ) + + +@pytest.fixture +def create_batch_completions_v2_request() -> CreateBatchCompletionsV2Request: + return CreateBatchCompletionsV2Request( + output_data_path="test_output_data_path", + content=[ + FilteredCompletionV2Request( + prompt="What is machine learning?", + max_tokens=10, + temperature=0.5, + ) + ], + model_config=CreateBatchCompletionsV2ModelConfig( + model="mpt-7b", + checkpoint_path="s3://test_checkpoint_path", + labels={}, + num_shards=1, + ), + data_parallelism=1, + ) + + +@pytest.fixture +def create_batch_completions_v2_request_with_hardware() -> CreateBatchCompletionsV2Request: + return CreateBatchCompletionsV2Request( + output_data_path="test_output_data_path", + content=[ + FilteredCompletionV2Request( + prompt="What is machine learning?", + max_tokens=10, + temperature=0.5, + ) + ], + model_config=CreateBatchCompletionsV2ModelConfig( + model="mpt-7b", + checkpoint_path="s3://test_checkpoint_path", + labels={}, + num_shards=1, + ), + data_parallelism=1, + cpus=1, + gpus=1, + memory="8G", + gpu_type=GpuType.NVIDIA_HOPPER_H100, + storage="10G", + ) diff --git a/server/tests/unit/domain/test_async_inference_use_cases.py b/model-engine/tests/unit/domain/test_async_inference_use_cases.py similarity index 91% rename from server/tests/unit/domain/test_async_inference_use_cases.py rename to model-engine/tests/unit/domain/test_async_inference_use_cases.py index 9b027907..4334480d 100644 --- a/server/tests/unit/domain/test_async_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_async_inference_use_cases.py @@ -1,14 +1,14 @@ from typing import Any, Dict, Tuple import pytest -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.domain.use_cases.async_inference_use_cases import ( +from model_engine_server.domain.use_cases.async_inference_use_cases import ( CreateAsyncInferenceTaskV1UseCase, GetAsyncInferenceTaskV1UseCase, ) diff --git a/server/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py b/model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py similarity index 92% rename from server/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py rename to model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py index 51852f30..4f62b79d 100644 --- a/server/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py +++ b/model-engine/tests/unit/domain/test_docker_image_batch_job_bundle_use_cases.py @@ -1,16 +1,16 @@ import pytest -from llm_engine_server.common.dtos.batch_jobs import ( +from model_engine_server.common.dtos.batch_jobs import ( CreateDockerImageBatchJobBundleV1Request, CreateDockerImageBatchJobBundleV1Response, ) -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.repositories import DockerImageBatchJobBundleRepository -from llm_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( +from model_engine_server.domain.repositories import DockerImageBatchJobBundleRepository +from model_engine_server.domain.use_cases.docker_image_batch_job_bundle_use_cases import ( CreateDockerImageBatchJobBundleV1UseCase, GetDockerImageBatchJobBundleByIdV1UseCase, GetLatestDockerImageBatchJobBundleByNameV1UseCase, @@ -56,9 +56,7 @@ async def test_create_list_docker_image_batch_job_bundle_use_case( user=user, request=create_docker_image_batch_job_bundle_request ) response_2 = await use_case_list.execute( - user=user, - bundle_name=create_docker_image_batch_job_bundle_request.name, - order_by=None, + user=user, bundle_name=create_docker_image_batch_job_bundle_request.name, order_by=None ) assert len(response_2.docker_image_batch_job_bundles) == 1 assert ( @@ -103,9 +101,7 @@ async def test_create_list_docker_image_batch_job_bundle_team_use_case( ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) user_other_team_1 = User( - user_id=test_api_key_user_on_other_team, - team_id=test_api_key_team, - is_privileged_user=True, + user_id=test_api_key_user_on_other_team, team_id=test_api_key_team, is_privileged_user=True ) user_other_team_2 = User( user_id=test_api_key_user_on_other_team_2, @@ -121,9 +117,7 @@ async def test_create_list_docker_image_batch_job_bundle_team_use_case( ) await use_case_create.execute(user=user, request=create_docker_image_batch_job_bundle_request) response_2 = await use_case_list.execute( - user=user, - bundle_name=create_docker_image_batch_job_bundle_request.name, - order_by=None, + user=user, bundle_name=create_docker_image_batch_job_bundle_request.name, order_by=None ) assert len(response_2.docker_image_batch_job_bundles) == 1 response_3 = await use_case_list.execute( @@ -209,9 +203,7 @@ async def test_create_get_docker_image_batch_job_bundle_by_id_unauthorized_use_c ): user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) user_other_team_1 = User( - user_id=test_api_key_user_on_other_team, - team_id=test_api_key_team, - is_privileged_user=True, + user_id=test_api_key_user_on_other_team, team_id=test_api_key_team, is_privileged_user=True ) use_case_create = CreateDockerImageBatchJobBundleV1UseCase( docker_image_batch_job_bundle_repo=fake_docker_image_batch_job_bundle_repository diff --git a/server/tests/unit/domain/test_entities.py b/model-engine/tests/unit/domain/test_entities.py similarity index 79% rename from server/tests/unit/domain/test_entities.py rename to model-engine/tests/unit/domain/test_entities.py index dc9a8e56..cd0ab507 100644 --- a/server/tests/unit/domain/test_entities.py +++ b/model-engine/tests/unit/domain/test_entities.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( CallbackAuth, CallbackBasicAuth, ModelBundle, @@ -15,6 +15,7 @@ bundle_name="test_bundle", post_inference_hooks=["test_hook"], user_id="test_user", + billing_queue="test_queue", default_callback_url="test_url", ), ModelEndpointConfig( @@ -22,17 +23,14 @@ bundle_name="test_bundle", post_inference_hooks=["test_hook"], user_id="test_user", + billing_queue="test_queue", default_callback_auth=CallbackAuth( - __root__=CallbackBasicAuth( - kind="basic", username="test_user", password="test_password" - ) + root=CallbackBasicAuth(kind="basic", username="test_user", password="test_password") ), ), ], ) -def test_model_endpoint_config_serialization( - model_endpoint_config: ModelEndpointConfig, -): +def test_model_endpoint_config_serialization(model_endpoint_config: ModelEndpointConfig): serialized_config = model_endpoint_config.serialize() deserialized_config = ModelEndpointConfig.deserialize(serialized_config) assert model_endpoint_config == deserialized_config diff --git a/model-engine/tests/unit/domain/test_llm_use_cases.py b/model-engine/tests/unit/domain/test_llm_use_cases.py new file mode 100644 index 00000000..f1392168 --- /dev/null +++ b/model-engine/tests/unit/domain/test_llm_use_cases.py @@ -0,0 +1,2997 @@ +import json +from typing import Any, List, Tuple +from unittest import mock + +import pytest +from model_engine_server.common.dtos.batch_jobs import CreateDockerImageBatchJobResourceRequests +from model_engine_server.common.dtos.llms import ( + CompletionOutput, + CompletionStreamV1Request, + CompletionSyncV1Request, + CreateBatchCompletionsV1Request, + CreateFineTuneRequest, + CreateLLMModelEndpointV1Request, + CreateLLMModelEndpointV1Response, + ModelDownloadRequest, + TokenOutput, + UpdateLLMModelEndpointV1Request, +) +from model_engine_server.common.dtos.llms.batch_completion import ( + CreateBatchCompletionsEngineRequest, + CreateBatchCompletionsV2Request, +) +from model_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.entities import ( + LLMInferenceFramework, + ModelEndpoint, + ModelEndpointType, +) +from model_engine_server.domain.exceptions import ( + DockerImageNotFoundException, + EndpointUnsupportedInferenceTypeException, + InvalidRequestException, + LLMFineTuningQuotaReached, + ObjectHasInvalidValueException, + ObjectNotAuthorizedException, + ObjectNotFoundException, + UpstreamServiceError, +) +from model_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( + MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER, + CreateFineTuneV1UseCase, + GetFineTuneEventsV1UseCase, + is_model_name_suffix_valid, +) +from model_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( + CHAT_TEMPLATE_MAX_LENGTH, + CompletionStreamV1UseCase, + CompletionSyncV1UseCase, + CreateBatchCompletionsUseCase, + CreateBatchCompletionsV2UseCase, + CreateLLMModelBundleV1UseCase, + CreateLLMModelEndpointV1UseCase, + DeleteLLMEndpointByNameUseCase, + GetLLMModelEndpointByNameV1UseCase, + GpuType, + ModelDownloadV1UseCase, + UpdateLLMModelEndpointV1UseCase, + _fill_hardware_info, + _infer_hardware, + merge_metadata, + validate_and_update_completion_params, + validate_chat_template, + validate_checkpoint_files, +) +from model_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase + +from ..conftest import mocked__get_recommended_hardware_config_map + + +def mocked__get_latest_batch_v2_tag(): + async def async_mock(*args, **kwargs): # noqa + return "fake_docker_repository_latest_image_tag" + + return mock.AsyncMock(side_effect=async_mock) + + +def mocked__get_latest_batch_tag(): + async def async_mock(*args, **kwargs): # noqa + return "fake_docker_repository_latest_image_tag" + + return mock.AsyncMock(side_effect=async_mock) + + +def mocked__get_latest_tag(): + async def async_mock(*args, **kwargs): # noqa + return "fake_docker_repository_latest_image_tag" + + return mock.AsyncMock(side_effect=async_mock) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag", + mocked__get_latest_tag(), +) +async def test_create_model_endpoint_use_case_success( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_request_async: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_request_sync: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_request_llama_2: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_request_llama_3_70b: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_request_llama_3_1_405b_instruct: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async) + assert response_1.endpoint_creation_task_id + assert isinstance(response_1, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_async.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.ASYNC + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_async.model_name, + "source": create_llm_model_endpoint_request_async.source, + "inference_framework": create_llm_model_endpoint_request_async.inference_framework, + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", + "num_shards": create_llm_model_endpoint_request_async.num_shards, + "quantize": None, + "checkpoint_path": create_llm_model_endpoint_request_async.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_async.chat_template_override, + } + } + + response_2 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_sync) + assert response_2.endpoint_creation_task_id + assert isinstance(response_2, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_sync.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.SYNC + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_sync.model_name, + "source": create_llm_model_endpoint_request_sync.source, + "inference_framework": create_llm_model_endpoint_request_sync.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_request_sync.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_request_sync.num_shards, + "quantize": None, + "checkpoint_path": create_llm_model_endpoint_request_sync.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_sync.chat_template_override, + } + } + + response_3 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_streaming + ) + assert response_3.endpoint_creation_task_id + assert isinstance(response_3, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_streaming.model_name, + "source": create_llm_model_endpoint_request_streaming.source, + "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_request_streaming.num_shards, + "quantize": None, + "checkpoint_path": create_llm_model_endpoint_request_streaming.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, + } + } + + response_4 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_llama_2 + ) + assert response_4.endpoint_creation_task_id + assert isinstance(response_4, CreateLLMModelEndpointV1Response) + bundle = await fake_model_bundle_repository.get_latest_model_bundle_by_name( + owner=user.team_id, name=create_llm_model_endpoint_request_llama_2.name + ) + assert "--max-total-tokens" in bundle.flavor.command[-1] and "4096" in bundle.flavor.command[-1] + + response_5 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_llama_3_70b + ) + assert response_5.endpoint_creation_task_id + assert isinstance(response_5, CreateLLMModelEndpointV1Response) + bundle = await fake_model_bundle_repository.get_latest_model_bundle_by_name( + owner=user.team_id, name=create_llm_model_endpoint_request_llama_3_70b.name + ) + assert " --gpu-memory-utilization 0.95" in bundle.flavor.command[-1] + + response_6 = await use_case.execute( + user=user, request=create_llm_model_endpoint_request_llama_3_1_405b_instruct + ) + assert response_6.endpoint_creation_task_id + assert isinstance(response_6, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_llama_3_1_405b_instruct.name, + order_by=None, + ) + )[0] + assert endpoint.infra_state.resource_state.nodes_per_worker == 2 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "inference_framework, model_name, checkpoint_path, expected_error", + [ + ( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + "mpt-7b", + None, + InvalidRequestException, + ), + ( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, + "mpt-7b-instruct", + "gibberish", + ObjectHasInvalidValueException, + ), + (LLMInferenceFramework.LIGHTLLM, "mpt-7b", None, InvalidRequestException), + ( + LLMInferenceFramework.LIGHTLLM, + "mpt-7b-instruct", + "gibberish", + ObjectHasInvalidValueException, + ), + (LLMInferenceFramework.VLLM, "mpt-7b", None, InvalidRequestException), + ( + LLMInferenceFramework.VLLM, + "mpt-7b-instruct", + "gibberish", + ObjectHasInvalidValueException, + ), + ], +) +async def test_create_model_bundle_fails_if_no_checkpoint( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request, + inference_framework, + model_name, + checkpoint_path, + expected_error, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() + + with pytest.raises(expected_error): + await use_case.execute( + user=user, + endpoint_name=request.name, + model_name=model_name, + source=request.source, + framework=inference_framework, + framework_image_tag="0.0.0", + endpoint_type=request.endpoint_type, + num_shards=request.num_shards, + quantize=request.quantize, + checkpoint_path=checkpoint_path, + chat_template_override=request.chat_template_override, + nodes_per_worker=1, + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "valid, inference_framework, inference_framework_image_tag", + [ + (False, LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "0.9.2"), + (True, LLMInferenceFramework.TEXT_GENERATION_INFERENCE, "0.9.3"), + (False, LLMInferenceFramework.VLLM, "0.1.6"), + (True, LLMInferenceFramework.VLLM, "0.1.3.6"), + ], +) +async def test_create_model_bundle_inference_framework_image_tag_validation( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_docker_repository_image_never_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request, + valid, + inference_framework, + inference_framework_image_tag, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + + request = create_llm_model_endpoint_text_generation_inference_request_streaming.copy() + request.inference_framework = inference_framework + request.inference_framework_image_tag = inference_framework_image_tag + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + if valid: + await use_case.execute(user=user, request=request) + else: + llm_bundle_use_case.docker_repository = fake_docker_repository_image_never_exists + with pytest.raises(DockerImageNotFoundException): + await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +async def test_create_model_endpoint_w_chat_template( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_request_llama_3_70b_chat: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response = await use_case.execute( + user=user, + request=create_llm_model_endpoint_request_llama_3_70b_chat, + ) + assert response.endpoint_creation_task_id + assert isinstance(response, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_llama_3_70b_chat.name, + order_by=None, + ) + )[0] + + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_llama_3_70b_chat.model_name, + "source": create_llm_model_endpoint_request_llama_3_70b_chat.source, + "inference_framework": create_llm_model_endpoint_request_llama_3_70b_chat.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_request_llama_3_70b_chat.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_request_llama_3_70b_chat.num_shards, + "quantize": create_llm_model_endpoint_request_llama_3_70b_chat.quantize, + "checkpoint_path": create_llm_model_endpoint_request_llama_3_70b_chat.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_llama_3_70b_chat.chat_template_override, + } + } + + +@pytest.mark.asyncio +async def test_create_model_endpoint_w_vllm_args( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response = await use_case.execute( + user=user, + request=create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args, + ) + assert response.endpoint_creation_task_id + assert isinstance(response, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.name, + order_by=None, + ) + )[0] + + bundle_command = endpoint.record.current_model_bundle.flavor.command[2] + expected_vllm_args = ["max-model-len", "max-num-seqs", "chat-template"] + for arg in expected_vllm_args: + assert arg in bundle_command + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.model_name, + "source": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.source, + "inference_framework": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.num_shards, + "quantize": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.quantize, + "checkpoint_path": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_llama_3_70b_chat_vllm_args.chat_template_override, + } + } + + +@pytest.mark.asyncio +async def test_create_model_endpoint_text_generation_inference_use_case_success( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_text_generation_inference_request_async: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + request=create_llm_model_endpoint_text_generation_inference_request_streaming, + ) + assert response_1.endpoint_creation_task_id + assert isinstance(response_1, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_text_generation_inference_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_text_generation_inference_request_streaming.model_name, + "source": create_llm_model_endpoint_text_generation_inference_request_streaming.source, + "inference_framework": create_llm_model_endpoint_text_generation_inference_request_streaming.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_text_generation_inference_request_streaming.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_text_generation_inference_request_streaming.num_shards, + "quantize": create_llm_model_endpoint_text_generation_inference_request_streaming.quantize, + "checkpoint_path": create_llm_model_endpoint_text_generation_inference_request_streaming.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_text_generation_inference_request_streaming.chat_template_override, + } + } + + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute( + user=user, + request=create_llm_model_endpoint_text_generation_inference_request_async, + ) + + +def test_load_model_weights_sub_commands( + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + + framework = LLMInferenceFramework.VLLM + framework_image_tag = "0.2.7" + checkpoint_path = "s3://fake-checkpoint" + final_weights_folder = "test_folder" + + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + + expected_result = [ + './s5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*" s3://fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + + trust_remote_code = True + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code + ) + + expected_result = [ + './s5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*" --include "*.py" s3://fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + + framework = LLMInferenceFramework.TEXT_GENERATION_INFERENCE + framework_image_tag = "1.0.0" + checkpoint_path = "s3://fake-checkpoint" + final_weights_folder = "test_folder" + + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + + expected_result = [ + "s5cmd > /dev/null || conda install -c conda-forge -y s5cmd", + 's5cmd --numworkers 512 cp --concurrency 10 --include "*.model" --include "*.json" --include "*.safetensors" --exclude "optimizer*" s3://fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + + framework = LLMInferenceFramework.VLLM + framework_image_tag = "0.2.7" + checkpoint_path = "azure://fake-checkpoint" + final_weights_folder = "test_folder" + + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder + ) + + expected_result = [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + 'azcopy copy --recursive --include-pattern "*.model;*.json;*.safetensors" --exclude-pattern "optimizer*" azure://fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + + trust_remote_code = True + subcommands = llm_bundle_use_case.load_model_weights_sub_commands( + framework, framework_image_tag, checkpoint_path, final_weights_folder, trust_remote_code + ) + + expected_result = [ + "export AZCOPY_AUTO_LOGIN_TYPE=WORKLOAD", + "curl -L https://aka.ms/downloadazcopy-v10-linux | tar --strip-components=1 -C /usr/local/bin --no-same-owner --exclude=*.txt -xzvf - && chmod 755 /usr/local/bin/azcopy", + 'azcopy copy --recursive --include-pattern "*.model;*.json;*.safetensors;*.py" --exclude-pattern "optimizer*" azure://fake-checkpoint/* test_folder', + ] + assert expected_result == subcommands + + +@pytest.mark.asyncio +async def test_create_model_endpoint_trt_llm_use_case_success( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_trt_llm_request_async: CreateLLMModelEndpointV1Request, + create_llm_model_endpoint_trt_llm_request_streaming: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + request=create_llm_model_endpoint_trt_llm_request_streaming, + ) + assert response_1.endpoint_creation_task_id + assert isinstance(response_1, CreateLLMModelEndpointV1Response) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_trt_llm_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_trt_llm_request_streaming.model_name, + "source": create_llm_model_endpoint_trt_llm_request_streaming.source, + "inference_framework": create_llm_model_endpoint_trt_llm_request_streaming.inference_framework, + "inference_framework_image_tag": create_llm_model_endpoint_trt_llm_request_streaming.inference_framework_image_tag, + "num_shards": create_llm_model_endpoint_trt_llm_request_streaming.num_shards, + "quantize": create_llm_model_endpoint_trt_llm_request_streaming.quantize, + "checkpoint_path": create_llm_model_endpoint_trt_llm_request_streaming.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_trt_llm_request_streaming.chat_template_override, + } + } + + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute( + user=user, + request=create_llm_model_endpoint_trt_llm_request_async, + ) + + +@pytest.mark.asyncio +async def test_create_llm_model_endpoint_use_case_quantization_exception( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + create_llm_model_endpoint_request_invalid_quantization: CreateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute( + user=user, request=create_llm_model_endpoint_request_invalid_quantization + ) + + +@pytest.mark.asyncio +async def test_get_llm_model_endpoint_use_case_raises_not_found( + test_api_key: str, + fake_llm_model_endpoint_service, + llm_model_endpoint_async: Tuple[ModelEndpoint, Any], +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_async[0]) + use_case = GetLLMModelEndpointByNameV1UseCase( + llm_model_endpoint_service=fake_llm_model_endpoint_service + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(ObjectNotFoundException): + await use_case.execute(user=user, model_endpoint_name="invalid_model_endpoint_name") + + +@pytest.mark.asyncio +async def test_get_llm_model_endpoint_use_case_raises_not_authorized( + test_api_key: str, + fake_llm_model_endpoint_service, + llm_model_endpoint_async: Tuple[ModelEndpoint, Any], +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_async[0]) + use_case = GetLLMModelEndpointByNameV1UseCase( + llm_model_endpoint_service=fake_llm_model_endpoint_service + ) + llm_model_endpoint_async[0].record.public_inference = False + user = User(user_id="non_exist", team_id="non_exist", is_privileged_user=False) + with pytest.raises(ObjectNotAuthorizedException): + await use_case.execute( + user=user, model_endpoint_name=llm_model_endpoint_async[0].record.name + ) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag", + mocked__get_latest_tag(), +) +async def test_update_model_endpoint_use_case_success( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + fake_llm_model_endpoint_service, + create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, + update_llm_model_endpoint_request: UpdateLLMModelEndpointV1Request, + update_llm_model_endpoint_request_only_workers: UpdateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + create_use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + update_use_case = UpdateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + + await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + fake_llm_model_endpoint_service.add_model_endpoint(endpoint) + + update_response = await update_use_case.execute( + user=user, + model_endpoint_name=create_llm_model_endpoint_request_streaming.name, + request=update_llm_model_endpoint_request, + ) + assert update_response.endpoint_creation_task_id + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_streaming.model_name, + "source": create_llm_model_endpoint_request_streaming.source, + "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", + "num_shards": create_llm_model_endpoint_request_streaming.num_shards, + "quantize": None, + "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, + } + } + assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory + assert ( + endpoint.infra_state.deployment_state.min_workers + == update_llm_model_endpoint_request.min_workers + ) + assert ( + endpoint.infra_state.deployment_state.max_workers + == update_llm_model_endpoint_request.max_workers + ) + + update_response2 = await update_use_case.execute( + user=user, + model_endpoint_name=create_llm_model_endpoint_request_streaming.name, + request=update_llm_model_endpoint_request_only_workers, + ) + assert update_response2.endpoint_creation_task_id + + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + assert endpoint.record.metadata == { + "_llm": { + "model_name": create_llm_model_endpoint_request_streaming.model_name, + "source": create_llm_model_endpoint_request_streaming.source, + "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, + "inference_framework_image_tag": "fake_docker_repository_latest_image_tag", + "num_shards": create_llm_model_endpoint_request_streaming.num_shards, + "quantize": None, + "checkpoint_path": update_llm_model_endpoint_request.checkpoint_path, + "chat_template_override": create_llm_model_endpoint_request_streaming.chat_template_override, + } + } + assert endpoint.infra_state.resource_state.memory == update_llm_model_endpoint_request.memory + assert ( + endpoint.infra_state.deployment_state.min_workers + == update_llm_model_endpoint_request_only_workers.min_workers + ) + assert ( + endpoint.infra_state.deployment_state.max_workers + == update_llm_model_endpoint_request_only_workers.max_workers + ) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_tag", + mocked__get_latest_tag(), +) +async def test_update_model_endpoint_use_case_failure( + test_api_key: str, + fake_model_bundle_repository, + fake_model_endpoint_service, + fake_docker_repository_image_always_exists, + fake_model_primitive_gateway, + fake_llm_artifact_gateway, + fake_llm_model_endpoint_service, + create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, + update_llm_model_endpoint_request_bad_metadata: UpdateLLMModelEndpointV1Request, +): + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + bundle_use_case = CreateModelBundleV2UseCase( + model_bundle_repository=fake_model_bundle_repository, + docker_repository=fake_docker_repository_image_always_exists, + model_primitive_gateway=fake_model_primitive_gateway, + ) + llm_bundle_use_case = CreateLLMModelBundleV1UseCase( + create_model_bundle_use_case=bundle_use_case, + model_bundle_repository=fake_model_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + docker_repository=fake_docker_repository_image_always_exists, + ) + create_use_case = CreateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + update_use_case = UpdateLLMModelEndpointV1UseCase( + create_llm_model_bundle_use_case=llm_bundle_use_case, + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + docker_repository=fake_docker_repository_image_always_exists, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + + await create_use_case.execute(user=user, request=create_llm_model_endpoint_request_streaming) + endpoint = ( + await fake_model_endpoint_service.list_model_endpoints( + owner=None, + name=create_llm_model_endpoint_request_streaming.name, + order_by=None, + ) + )[0] + fake_llm_model_endpoint_service.add_model_endpoint(endpoint) + + with pytest.raises(ObjectHasInvalidValueException): + await update_use_case.execute( + user=user, + model_endpoint_name=create_llm_model_endpoint_request_streaming.name, + request=update_llm_model_endpoint_request_bad_metadata, + ) + + +def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa + class mocked_encode: + def encode(self, input: str) -> List[Any]: # noqa + return [1] * 7 + + return mocked_encode() + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.AutoTokenizer.from_pretrained", + mocked_auto_tokenizer_from_pretrained, +) +async def test_completion_sync_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + completion_sync_request.include_stop_str_in_output = True + completion_sync_request.guided_json = {} + completion_sync_request.skip_special_tokens = False + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": json.dumps( + { + "text": "I am a newbie to the world of programming.", + "tokens": [ + "I", + " am", + " a", + " new", + "bie", + " to", + " the", + " world", + " of", + " programming", + ".", + ], + "log_probs": [ + {1: -2.3025850929940455}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + {1: 0}, + ], + "count_prompt_tokens": 7, + "count_output_tokens": 11, + } + ) + }, + traceback=None, + ) + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync[0].record.name, + request=completion_sync_request, + ) + assert response_1.output == CompletionOutput( + text="I am a newbie to the world of programming.", + num_prompt_tokens=7, + num_completion_tokens=11, + tokens=[ + TokenOutput(token="I", log_prob=-2.3025850929940455), + TokenOutput(token=" am", log_prob=0), + TokenOutput(token=" a", log_prob=0), + TokenOutput(token=" new", log_prob=0), + TokenOutput(token="bie", log_prob=0), + TokenOutput(token=" to", log_prob=0), + TokenOutput(token=" the", log_prob=0), + TokenOutput(token=" world", log_prob=0), + TokenOutput(token=" of", log_prob=0), + TokenOutput(token=" programming", log_prob=0), + TokenOutput(token=".", log_prob=0), + ], + ) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=5, +) +async def test_completion_sync_text_generation_inference_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_text_generation_inference: ModelEndpoint, + completion_sync_request: CompletionSyncV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": """ + { + "generated_text": " Deep Learning is a new type of machine learning", + "details": { + "finish_reason": "length", + "generated_tokens": 9, + "prefill": [ + { + "id": 10560, + "text": "What" + }, + { + "id": 632, + "text": " is" + }, + { + "id": 89554, + "text": " Deep" + }, + { + "id": 89950, + "text": " Learning" + }, + { + "id": 34, + "text": "?" + } + ], + "tokens": [ + { + "text": " Deep", + "logprob": 0 + }, + { + "text": " Learning", + "logprob": -1 + }, + { + "text": " is", + "logprob": 0 + }, + { + "text": " a", + "logprob": 0 + }, + { + "text": " new", + "logprob": 0 + }, + { + "text": " type", + "logprob": 0 + }, + { + "text": " of", + "logprob": 0 + }, + { + "text": " machine", + "logprob": 0 + }, + { + "text": " learning", + "logprob": 0 + } + ] + } + } +""" + }, + traceback=None, + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_text_generation_inference.record.name, + request=completion_sync_request, + ) + assert response_1.output == CompletionOutput( + text=" Deep Learning is a new type of machine learning", + num_prompt_tokens=5, + num_completion_tokens=9, + tokens=[ + TokenOutput(token=" Deep", log_prob=0.0), + TokenOutput(token=" Learning", log_prob=-1.0), + TokenOutput(token=" is", log_prob=0.0), + TokenOutput(token=" a", log_prob=0.0), + TokenOutput(token=" new", log_prob=0.0), + TokenOutput(token=" type", log_prob=0.0), + TokenOutput(token=" of", log_prob=0.0), + TokenOutput(token=" machine", log_prob=0.0), + TokenOutput(token=" learning", log_prob=0.0), + ], + ) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=6, +) +async def test_completion_sync_trt_llm_use_case_success_23_10( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_trt_llm: ModelEndpoint, + completion_sync_request: CompletionSyncV1Request, +): + completion_sync_request.return_token_log_probs = False # not yet supported + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_trt_llm) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": '{"model_name": "ensemble", "model_version": "1", "sequence_end": false, "sequence_id": 0, "sequence_start": false, "text_output": " What is machine learning? Machine learning is a branch", "token_ids": [1, 1724, 338, 4933, 6509, 29973, 6189, 6509, 338, 263, 5443]}' + }, + traceback=None, + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_trt_llm.record.name, + request=completion_sync_request, + ) + assert response_1.output == CompletionOutput( + text=" Machine learning is a branch", + num_prompt_tokens=6, + num_completion_tokens=5, + ) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=6, +) +@pytest.mark.parametrize( + "output_log_probs,output_tokens", [("[0.0,0.0,0.0,0.0,0.0]", 5), ("0.0", 1)] +) +async def test_completion_sync_trt_llm_use_case_success_24_01( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_trt_llm: ModelEndpoint, + completion_sync_request: CompletionSyncV1Request, + output_log_probs: str, + output_tokens: int, +): + completion_sync_request.return_token_log_probs = False # not yet supported + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_trt_llm) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": f'{{"context_logits":0.0,"cum_log_probs":0.0,"generation_logits":0.0,"model_name":"ensemble","model_version":"1","output_log_probs":{output_log_probs},"sequence_end":false,"sequence_id":0,"sequence_start":false,"text_output":" Machine learning is a branch"}}' + }, + traceback=None, + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_trt_llm.record.name, + request=completion_sync_request, + ) + assert response_1.output == CompletionOutput( + text=" Machine learning is a branch", + num_prompt_tokens=6, + num_completion_tokens=output_tokens, + ) + + +@pytest.mark.asyncio +async def test_completion_sync_use_case_predict_failed( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( + SyncEndpointPredictV1Response( + status=TaskStatus.FAILURE, + result=None, + traceback="failed to predict", + ) + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(UpstreamServiceError): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync[0].record.name, + request=completion_sync_request, + ) + + +@pytest.mark.asyncio +async def test_completion_sync_use_case_predict_failed_lightllm( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_sync_lightllm: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_lightllm[0]) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( + SyncEndpointPredictV1Response( + status=TaskStatus.FAILURE, + result=None, + traceback="failed to predict", + ) + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(UpstreamServiceError): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync_lightllm[0].record.name, + request=completion_sync_request, + ) + + +@pytest.mark.asyncio +async def test_completion_sync_use_case_predict_failed_trt_llm( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_sync_trt_llm: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + completion_sync_request.return_token_log_probs = False # not yet supported + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_trt_llm[0]) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( + SyncEndpointPredictV1Response( + status=TaskStatus.FAILURE, + result=None, + traceback="failed to predict", + ) + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(UpstreamServiceError): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync_trt_llm[0].record.name, + request=completion_sync_request, + ) + + +@pytest.mark.asyncio +async def test_completion_sync_use_case_predict_failed_with_errors( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_sync_tgi: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync_tgi[0]) + fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": """ + { + "error": "Request failed during generation: Server error: transport error", + "error_type": "generation" + } +""" + }, + traceback="failed to predict", + ) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(UpstreamServiceError): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_sync_tgi[0].record.name, + request=completion_sync_request, + ) + + +@pytest.mark.asyncio +async def test_completion_sync_use_case_not_sync_endpoint_raises( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_async: Tuple[ModelEndpoint, Any], + completion_sync_request: CompletionSyncV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_async[0]) + use_case = CompletionSyncV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(EndpointUnsupportedInferenceTypeException): + await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_async[0].record.name, + request=completion_sync_request, + ) + + +@pytest.mark.asyncio +async def test_validate_and_update_completion_params(): + completion_sync_request = CompletionSyncV1Request( + prompt="What is machine learning?", + max_new_tokens=10, + temperature=0.5, + return_token_log_probs=True, + ) + + validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) + + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + + completion_sync_request.include_stop_str_in_output = True + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + completion_sync_request.include_stop_str_in_output = None + + completion_sync_request.guided_regex = "" + completion_sync_request.guided_json = {} + completion_sync_request.guided_choice = [""] + completion_sync_request.guided_grammar = "" + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params(LLMInferenceFramework.VLLM, completion_sync_request) + + completion_sync_request.guided_regex = None + completion_sync_request.guided_choice = None + completion_sync_request.guided_grammar = None + with pytest.raises(ObjectHasInvalidValueException): + validate_and_update_completion_params( + LLMInferenceFramework.TEXT_GENERATION_INFERENCE, completion_sync_request + ) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) +async def test_completion_stream_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_streaming: ModelEndpoint, + completion_stream_request: CompletionStreamV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_streaming) + fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": "I"}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": " am"}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": " a"}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": " new"}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": "bie"}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": "."}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "response": [ + { + "error": None, + "text": "I am a newbie.", + "token_probs": { + "tokens": [ + "I", + " am", + " a", + " new", + "bie", + ".", + ] + }, + "tokens_consumed": 25, + } + ] + } + }, + traceback=None, + ), + ] + use_case = CompletionStreamV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_streaming.record.name, + request=completion_stream_request, + ) + output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] + i = 0 + async for message in response_1: + assert message.dict()["output"]["text"] == output_texts[i] + if i == 6: + assert message.dict()["output"]["num_prompt_tokens"] == 7 + assert message.dict()["output"]["num_completion_tokens"] == 6 + i += 1 + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) +async def test_completion_stream_vllm_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_stream: Tuple[ModelEndpoint, Any], + completion_stream_request: CompletionStreamV1Request, +): + completion_stream_request.guided_json = {} + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_stream[0]) + fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": "I", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 1, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " am", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 2, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " a", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 3, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": " new", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 4, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": "bie", + "finished": False, + "count_prompt_tokens": 7, + "count_output_tokens": 5, + } + }, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={ + "result": { + "text": ".", + "finished": True, + "count_prompt_tokens": 7, + "count_output_tokens": 6, + } + }, + traceback=None, + ), + ] + use_case = CompletionStreamV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_stream[0].record.name, + request=completion_stream_request, + ) + output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] + i = 0 + async for message in response_1: + assert message.dict()["output"]["text"] == output_texts[i] + if i == 5: + assert message.dict()["output"]["num_prompt_tokens"] == 7 + assert message.dict()["output"]["num_completion_tokens"] == 6 + i += 1 + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) +async def test_completion_stream_text_generation_inference_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_text_generation_inference: ModelEndpoint, + completion_stream_request: CompletionStreamV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) + fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": {"text": "I"}}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": {"text": " am"}}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": {"text": " a"}}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": {"text": " new"}}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": {"text": "bie"}}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"token": {"text": "."}, "generated_text": "I am a newbie."}}, + traceback=None, + ), + ] + use_case = CompletionStreamV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_text_generation_inference.record.name, + request=completion_stream_request, + ) + output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] + i = 0 + async for message in response_1: + assert message.dict()["output"]["text"] == output_texts[i] + if i == 5: + assert message.dict()["output"]["num_prompt_tokens"] == 7 + assert message.dict()["output"]["num_completion_tokens"] == 6 + i += 1 + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases.count_tokens", + return_value=7, +) +async def test_completion_stream_trt_llm_use_case_success( + test_api_key: str, + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + fake_tokenizer_repository, + llm_model_endpoint_trt_llm: ModelEndpoint, + completion_stream_request: CompletionStreamV1Request, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_trt_llm) + fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "Machine", "token_ids": 6189}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "learning", "token_ids": 6509}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "is", "token_ids": 338}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "a", "token_ids": 263}}, + traceback=None, + ), + SyncEndpointPredictV1Response( + status=TaskStatus.SUCCESS, + result={"result": {"text_output": "branch", "token_ids": 5443}}, + traceback=None, + ), + ] + use_case = CompletionStreamV1UseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + tokenizer_repository=fake_tokenizer_repository, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response_1 = await use_case.execute( + user=user, + model_endpoint_name=llm_model_endpoint_trt_llm.record.name, + request=completion_stream_request, + ) + output_texts = ["Machine", "learning", "is", "a", "branch"] + i = 0 + async for message in response_1: + assert message.dict()["output"]["text"] == output_texts[i] + assert message.dict()["output"]["num_prompt_tokens"] == 7 + assert message.dict()["output"]["num_completion_tokens"] == i + 1 + i += 1 + + +@pytest.mark.asyncio +async def test_create_llm_fine_tune_model_name_valid(): + assert is_model_name_suffix_valid("model-name") + assert not is_model_name_suffix_valid("Hi There! This is an invalid model name.") + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,response"), +) +async def test_create_fine_tune_success( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + test_api_key: str, +): + use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix=None, + ) + response = await use_case.execute(user=user, request=request) + assert response.id + + # This erroring code is part of the service anyways + # request.suffix = "Invalid model suffix *&^&%^$^%&^*" + # with pytest.raises(InvalidRequestException): + # await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,response"), +) +async def test_create_fine_tune_limit( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + test_api_key: str, +): + use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=False) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix=None, + ) + for i in range(MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER): + if i == MAX_LLM_ENDPOINTS_PER_EXTERNAL_USER: + with pytest.raises(LLMFineTuningQuotaReached): + await use_case.execute(user=user, request=request) + else: + response = await use_case.execute(user=user, request=request) + assert response.id + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,response"), +) +async def test_create_fine_tune_long_suffix( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + test_api_key: str, +): + use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix="a" * 100, + ) + with pytest.raises(InvalidRequestException): + await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,not_response"), +) +async def test_create_fine_tune_invalid_headers( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + test_api_key: str, +): + use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix=None, + ) + with pytest.raises(InvalidRequestException): + await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_fine_tuning_use_cases.smart_open.open", + mock.mock_open(read_data="prompt,response"), +) +async def test_get_fine_tune_events_success( + fake_llm_fine_tuning_service, + fake_llm_fine_tuning_events_repository, + fake_model_endpoint_service, + fake_file_storage_gateway, + test_api_key: str, +): + populate_use_case = CreateFineTuneV1UseCase( + fake_llm_fine_tuning_service, + fake_model_endpoint_service, + fake_llm_fine_tuning_events_repository, + fake_file_storage_gateway, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = CreateFineTuneRequest( + model="base_model", + training_file="file1", + validation_file=None, + # fine_tuning_method="lora", + hyperparameters={}, + suffix=None, + ) + response = await populate_use_case.execute(user=user, request=request) + + use_case = GetFineTuneEventsV1UseCase( + llm_fine_tune_events_repository=fake_llm_fine_tuning_events_repository, + llm_fine_tuning_service=fake_llm_fine_tuning_service, + ) + response_2 = await use_case.execute(user=user, fine_tune_id=response.id) + assert len(response_2.events) == len(fake_llm_fine_tuning_events_repository.all_events_list) + + +@pytest.mark.asyncio +async def test_download_model_success( + fake_model_endpoint_service, + fake_filesystem_gateway, + fake_llm_artifact_gateway, + model_endpoint_1: ModelEndpoint, + test_api_key: str, +): + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + model_endpoint_1.record.owner = test_api_key + model_endpoint_1.record.name = "base_model" + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_llm_artifact_gateway._add_model(user.team_id, model_endpoint_1.record.name) + use_case = ModelDownloadV1UseCase( + fake_filesystem_gateway, fake_model_endpoint_service, fake_llm_artifact_gateway + ) + request = ModelDownloadRequest( + model_name=model_endpoint_1.record.name, + download_format="huggingface", + ) + response = await use_case.execute(user=user, request=request) + assert response.urls != {} + + +@pytest.mark.asyncio +async def test_download_nonexistent_model_raises_not_found( + fake_model_endpoint_service, + fake_filesystem_gateway, + fake_llm_artifact_gateway, + model_endpoint_1: ModelEndpoint, + test_api_key: str, +): + model_endpoint_1.record.owner = test_api_key + model_endpoint_1.record.name = "base_model" + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_llm_artifact_gateway._add_model(test_api_key, "base_model") + use_case = ModelDownloadV1UseCase( + fake_filesystem_gateway, fake_model_endpoint_service, fake_llm_artifact_gateway + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + request = ModelDownloadRequest( + model_name="nonexistent_model", + download_format="huggingface", + ) + with pytest.raises(ObjectNotFoundException): + await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +async def test_delete_model_success( + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + test_api_key: str, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + use_case = DeleteLLMEndpointByNameUseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + response = await use_case.execute( + user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name + ) + remaining_endpoint_model_service = await fake_model_endpoint_service.get_model_endpoint( + llm_model_endpoint_sync[0].record.id + ) + assert remaining_endpoint_model_service is None + assert response.deleted is True + + +@pytest.mark.asyncio +async def test_delete_nonexistent_model_raises_not_found( + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + test_api_key: str, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + use_case = DeleteLLMEndpointByNameUseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + with pytest.raises(ObjectNotFoundException): + await use_case.execute(user=user, model_endpoint_name="nonexistent-model") + + +@pytest.mark.asyncio +async def test_delete_unauthorized_model_raises_not_authorized( + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + use_case = DeleteLLMEndpointByNameUseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User(user_id="fakeapikey", team_id="fakeapikey", is_privileged_user=True) + with pytest.raises(ObjectNotAuthorizedException): + await use_case.execute( + user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name + ) + + +@pytest.mark.asyncio +async def test_delete_public_inference_model_raises_not_authorized( + fake_model_endpoint_service, + fake_llm_model_endpoint_service, + llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], + test_api_key, +): + fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + fake_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) + use_case = DeleteLLMEndpointByNameUseCase( + model_endpoint_service=fake_model_endpoint_service, + llm_model_endpoint_service=fake_llm_model_endpoint_service, + ) + user = User( + user_id="fakeapikey", team_id="faketeam", is_privileged_user=True + ) # write access is based on team_id, so team_id != owner's team_id + with pytest.raises( + ObjectNotAuthorizedException + ): # user cannot delete public inference model they don't own + await use_case.execute( + user=user, model_endpoint_name=llm_model_endpoint_sync[0].record.name + ) + + +@pytest.mark.asyncio +async def test_validate_checkpoint_files_no_safetensors(): + fake_model_files = ["model-fake.bin", "model.json", "optimizer.pt"] + with pytest.raises(ObjectHasInvalidValueException): + validate_checkpoint_files(fake_model_files) + + +@pytest.mark.asyncio +async def test_validate_checkpoint_files_safetensors_with_other_files(): + fake_model_files = [ + "model-fake.bin", + "model-fake2.safetensors", + "model.json", + "optimizer.pt", + ] + validate_checkpoint_files(fake_model_files) # No exception should be raised + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) +async def test_infer_hardware(fake_llm_artifact_gateway): + # deepseek from https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Instruct/raw/main/config.json + fake_llm_artifact_gateway.model_config = { + "architectures": ["DeepseekV2ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "aux_loss_alpha": 0.001, + "bos_token_id": 100000, + "eos_token_id": 100001, + "ep_size": 1, + "first_k_dense_replace": 1, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 12288, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v2", + "moe_intermediate_size": 1536, + "moe_layer_freq": 1, + "n_group": 8, + "n_routed_experts": 160, + "n_shared_experts": 2, + "norm_topk_prob": False, + "num_attention_heads": 128, + "num_experts_per_tok": 6, + "num_hidden_layers": 60, + "num_key_value_heads": 128, + "pretraining_tp": 1, + "q_lora_rank": 1536, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_theta": 10000, + "routed_scaling_factor": 16.0, + "scoring_func": "softmax", + "seq_aux": True, + "tie_word_embeddings": False, + "topk_group": 3, + "topk_method": "group_limited_greedy", + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": True, + "v_head_dim": 128, + "vocab_size": 102400, + } + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "") + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "deepseek-coder-v2-instruct", "", is_batch_job=True + ) + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + # deepseek lite https://huggingface.co/deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct/raw/main/config.json + fake_llm_artifact_gateway.model_config = { + "architectures": ["DeepseekV2ForCausalLM"], + "attention_bias": False, + "attention_dropout": 0.0, + "aux_loss_alpha": 0.001, + "bos_token_id": 100000, + "eos_token_id": 100001, + "first_k_dense_replace": 1, + "hidden_act": "silu", + "hidden_size": 2048, + "initializer_range": 0.02, + "intermediate_size": 10944, + "kv_lora_rank": 512, + "max_position_embeddings": 163840, + "model_type": "deepseek_v2", + "moe_intermediate_size": 1408, + "moe_layer_freq": 1, + "n_group": 1, + "n_routed_experts": 64, + "n_shared_experts": 2, + "norm_topk_prob": False, + "num_attention_heads": 16, + "num_experts_per_tok": 6, + "num_hidden_layers": 27, + "num_key_value_heads": 16, + "pretraining_tp": 1, + "qk_nope_head_dim": 128, + "qk_rope_head_dim": 64, + "rms_norm_eps": 1e-06, + "rope_theta": 10000, + "routed_scaling_factor": 1.0, + "scoring_func": "softmax", + "seq_aux": True, + "tie_word_embeddings": False, + "topk_group": 1, + "topk_method": "greedy", + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": True, + "v_head_dim": 128, + "vocab_size": 102400, + } + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "deepseek-coder-v2-lite-instruct", "" + ) + assert hardware.cpus == 20 + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, + "deepseek-coder-v2-lite-instruct", + "", + is_batch_job=True, + ) + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, + "deepseek-coder-v2-lite-instruct", + "", + is_batch_job=True, + max_context_length=4096, + ) + assert hardware.cpus == 20 + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + + # Phi 3 mini from https://huggingface.co/microsoft/Phi-3-mini-4k-instruct/blob/main/config.json + fake_llm_artifact_gateway.model_config = { + "architectures": ["Phi3ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 3072, + "initializer_range": 0.02, + "intermediate_size": 8192, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.2", + "use_cache": True, + "attention_bias": False, + "vocab_size": 32064, + } + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "") + assert hardware.cpus == 5 + assert hardware.gpus == 1 + assert hardware.memory == "20Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-mini-4k-instruct", "", is_batch_job=True + ) + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 + + # Phi 3 small from https://huggingface.co/microsoft/Phi-3-small-8k-instruct/blob/main/config.json + fake_llm_artifact_gateway.model_config = { + "architectures": ["Phi3SmallForCausalLM"], + "attention_dropout_prob": 0.0, + "blocksparse_block_size": 64, + "blocksparse_homo_head_pattern": False, + "blocksparse_num_local_blocks": 16, + "blocksparse_triton_kernel_block_size": 64, + "blocksparse_vert_stride": 8, + "bos_token_id": 100257, + "dense_attention_every_n_layers": 2, + "embedding_dropout_prob": 0.1, + "eos_token_id": 100257, + "ff_dim_multiplier": None, + "ff_intermediate_size": 14336, + "ffn_dropout_prob": 0.1, + "gegelu_limit": 20.0, + "gegelu_pad_to_256": True, + "hidden_act": "gegelu", + "hidden_size": 4096, + "initializer_range": 0.02, + "layer_norm_epsilon": 1e-05, + "max_position_embeddings": 8192, + "model_type": "phi3small", + "mup_attn_multiplier": 1.0, + "mup_embedding_multiplier": 10.0, + "mup_use_scaling": True, + "mup_width_multiplier": 8.0, + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pad_sequence_to_multiple_of_64": True, + "reorder_and_upcast_attn": False, + "rope_embedding_base": 1000000, + "rope_position_scale": 1.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.38.1", + "use_cache": True, + "attention_bias": False, + "vocab_size": 100352, + } + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "") + print(hardware) + assert hardware.cpus == 5 + assert hardware.gpus == 1 + assert hardware.memory == "20Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-small-8k-instruct", "", is_batch_job=True + ) + print(hardware) + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "architectures": ["Phi3ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "embd_pdrop": 0.0, + "eos_token_id": 32000, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 17920, + "max_position_embeddings": 4096, + "model_type": "phi3", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 10, + "original_max_position_embeddings": 4096, + "pad_token_id": 32000, + "resid_pdrop": 0.0, + "rms_norm_eps": 1e-05, + "rope_scaling": None, + "rope_theta": 10000.0, + "sliding_window": 2047, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.39.3", + "use_cache": True, + "attention_bias": False, + "vocab_size": 32064, + } + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "") + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "phi-3-medium-8k-instruct", "", is_batch_job=True + ) + assert hardware.cpus == 20 + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "architectures": ["MixtralForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 32768, + "model_type": "mixtral", + "num_attention_heads": 32, + "num_experts_per_tok": 2, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "num_local_experts": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000.0, + "router_aux_loss_coef": 0.02, + "torch_dtype": "bfloat16", + "transformers_version": "4.36.0.dev0", + "vocab_size": 32000, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x7b", "") + assert hardware.cpus == 40 + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "mixtral-8x7b", "", is_batch_job=True + ) + assert hardware.cpus == 40 + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "architectures": ["MixtralForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 6144, + "initializer_range": 0.02, + "intermediate_size": 16384, + "max_position_embeddings": 65536, + "model_type": "mixtral", + "num_attention_heads": 48, + "num_experts_per_tok": 2, + "num_hidden_layers": 56, + "num_key_value_heads": 8, + "num_local_experts": 8, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000, + "router_aux_loss_coef": 0.001, + "router_jitter_noise": 0.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "vocab_size": 32000, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "mixtral-8x22b", "") + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "mixtral-8x22b", "", is_batch_job=True + ) + assert hardware.cpus == 160 + assert hardware.gpus == 8 + assert hardware.memory == "800Gi" + assert hardware.storage == "640Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "_name_or_path": "meta-llama/Llama-2-7b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 11008, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 32, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "transformers_version": "4.31.0.dev0", + "vocab_size": 32000, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "") + assert hardware.cpus == 5 + assert hardware.gpus == 1 + assert hardware.memory == "20Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-7b", "", is_batch_job=True) + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "architectures": ["LlamaForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "vocab_size": 128256, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "") + assert hardware.cpus == 5 + assert hardware.gpus == 1 + assert hardware.memory == "20Gi" + assert hardware.storage == "40Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_1G_20GB + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b", "", is_batch_job=True) + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "_name_or_path": "meta-llama/Llama-2-13b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 5120, + "initializer_range": 0.02, + "intermediate_size": 13824, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 40, + "num_hidden_layers": 40, + "num_key_value_heads": 40, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "transformers_version": "4.32.0.dev0", + "vocab_size": 32000, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-13b", "") + assert hardware.cpus == 10 + assert hardware.gpus == 1 + assert hardware.memory == "40Gi" + assert hardware.storage == "80Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100_3G_40GB + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-2-13b", "", is_batch_job=True + ) + assert hardware.cpus == 20 + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 22016, + "max_position_embeddings": 16384, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 48, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 1000000, + "torch_dtype": "bfloat16", + "transformers_version": "4.32.0.dev0", + "vocab_size": 32000, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "codellama-34b", "") + assert hardware.cpus == 20 + assert hardware.gpus == 1 + assert hardware.memory == "80Gi" + assert hardware.storage == "96Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "codellama-34b", "", is_batch_job=True + ) + assert hardware.cpus == 40 + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "_name_or_path": "meta-llama/Llama-2-70b-hf", + "architectures": ["LlamaForCausalLM"], + "bos_token_id": 1, + "eos_token_id": 2, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 4096, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "torch_dtype": "float16", + "transformers_version": "4.32.0.dev0", + "vocab_size": 32000, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-2-70b", "") + assert hardware.cpus == 40 + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-2-70b", "", is_batch_job=True + ) + assert hardware.cpus == 80 + assert hardware.gpus == 4 + assert hardware.memory == "320Gi" + assert hardware.storage == "320Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "architectures": ["LlamaForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 28672, + "max_position_embeddings": 8192, + "model_type": "llama", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 500000.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.0.dev0", + "vocab_size": 128256, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-70b", "") + assert hardware.cpus == 40 + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + hardware = await _infer_hardware( + fake_llm_artifact_gateway, "llama-3-70b", "", is_batch_job=True + ) + assert hardware.cpus == 80 + assert hardware.gpus == 4 + assert hardware.memory == "320Gi" + assert hardware.storage == "320Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "_name_or_path": "gradientai/llama3-8b-stage65k-chat", + "architectures": ["LlamaForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 128000, + "eos_token_id": 128001, + "hidden_act": "silu", + "hidden_size": 4096, + "initializer_range": 0.02, + "intermediate_size": 14336, + "max_position_embeddings": 262144, + "model_type": "llama", + "num_attention_heads": 32, + "num_hidden_layers": 32, + "num_key_value_heads": 8, + "pretraining_tp": 1, + "rms_norm_eps": 1e-05, + "rope_theta": 283461213.0, + "torch_dtype": "bfloat16", + "transformers_version": "4.41.0.dev0", + "vocab_size": 128256, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "llama-3-8b-instruct-262k", "") + assert hardware.cpus == 40 + assert hardware.gpus == 2 + assert hardware.memory == "160Gi" + assert hardware.storage == "160Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + fake_llm_artifact_gateway.model_config = { + "architectures": ["Qwen2ForCausalLM"], + "attention_dropout": 0.0, + "bos_token_id": 151643, + "eos_token_id": 151645, + "hidden_act": "silu", + "hidden_size": 8192, + "initializer_range": 0.02, + "intermediate_size": 29568, + "max_position_embeddings": 32768, + "max_window_layers": 80, + "model_type": "qwen2", + "num_attention_heads": 64, + "num_hidden_layers": 80, + "num_key_value_heads": 8, + "rms_norm_eps": 1e-06, + "rope_theta": 1000000.0, + "sliding_window": 131072, + "tie_word_embeddings": False, + "torch_dtype": "bfloat16", + "transformers_version": "4.40.1", + "use_cache": True, + "use_sliding_window": False, + "vocab_size": 152064, + } + hardware = await _infer_hardware(fake_llm_artifact_gateway, "qwen2-72b-instruct", "") + assert hardware.cpus == 80 + assert hardware.gpus == 4 + assert hardware.memory == "320Gi" + assert hardware.storage == "320Gi" + assert hardware.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert hardware.nodes_per_worker == 1 + + with pytest.raises(ObjectHasInvalidValueException): + await _infer_hardware(fake_llm_artifact_gateway, "unsupported_model", "") + + with pytest.raises(ObjectHasInvalidValueException): + await _infer_hardware(fake_llm_artifact_gateway, "llama-3-999b", "") + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) +async def test_fill_hardware_info(fake_llm_artifact_gateway): + request = CreateLLMModelEndpointV1Request( + name="mixtral-8x7b", + model_name="mixtral-8x7b", + checkpoint_path="s3://checkpoint", + metadata={}, + min_workers=1, + max_workers=1, + per_worker=1, + labels={}, + ) + await _fill_hardware_info(fake_llm_artifact_gateway, request) + assert request.cpus == 40 + assert request.gpus == 2 + assert request.memory == "160Gi" + assert request.storage == "160Gi" + assert request.gpu_type == GpuType.NVIDIA_HOPPER_H100 + assert request.nodes_per_worker == 1 + + request = CreateLLMModelEndpointV1Request( + name="mixtral-8x7b", + model_name="mixtral-8x7b", + checkpoint_path="s3://checkpoint", + metadata={}, + min_workers=1, + max_workers=1, + per_worker=1, + labels={}, + gpus=1, + ) + + with pytest.raises(ObjectHasInvalidValueException): + await _fill_hardware_info(fake_llm_artifact_gateway, request) + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_batch_tag", + mocked__get_latest_batch_tag(), +) +async def test_create_batch_completions_v1( + fake_docker_image_batch_job_gateway, + fake_docker_repository_image_always_exists, + fake_docker_image_batch_job_bundle_repository, + fake_llm_artifact_gateway, + test_api_key: str, + create_batch_completions_v1_request: CreateBatchCompletionsV1Request, +): + use_case = CreateBatchCompletionsUseCase( + docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway, + docker_repository=fake_docker_repository_image_always_exists, + docker_image_batch_job_bundle_repo=fake_docker_image_batch_job_bundle_repository, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + result = await use_case.execute(user, create_batch_completions_v1_request) + + job = await fake_docker_image_batch_job_gateway.get_docker_image_batch_job(result.job_id) + assert job.num_workers == create_batch_completions_v1_request.data_parallelism + + bundle = list(fake_docker_image_batch_job_bundle_repository.db.values())[0] + assert bundle.command == [ + "dumb-init", + "--", + "/bin/bash", + "-c", + "ddtrace-run python vllm_batch.py", + ] + + +@pytest.mark.asyncio +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_recommended_hardware_config_map", + mocked__get_recommended_hardware_config_map(), +) +@mock.patch( + "model_engine_server.domain.use_cases.llm_model_endpoint_use_cases._get_latest_batch_v2_tag", + mocked__get_latest_batch_v2_tag(), +) +async def test_create_batch_completions_v2( + fake_llm_batch_completions_service, + fake_llm_artifact_gateway, + test_api_key: str, + create_batch_completions_v2_request: CreateBatchCompletionsV2Request, + create_batch_completions_v2_request_with_hardware: CreateBatchCompletionsV2Request, +): + fake_llm_batch_completions_service.create_batch_job = mock.AsyncMock() + use_case = CreateBatchCompletionsV2UseCase( + llm_batch_completions_service=fake_llm_batch_completions_service, + llm_artifact_gateway=fake_llm_artifact_gateway, + ) + + user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) + await use_case.execute(create_batch_completions_v2_request, user) + + expected_engine_request = CreateBatchCompletionsEngineRequest( + model_cfg=create_batch_completions_v2_request.model_cfg, + max_runtime_sec=create_batch_completions_v2_request.max_runtime_sec, + data_parallelism=create_batch_completions_v2_request.data_parallelism, + labels=create_batch_completions_v2_request.labels, + content=create_batch_completions_v2_request.content, + output_data_path=create_batch_completions_v2_request.output_data_path, + ) + + expected_hardware = CreateDockerImageBatchJobResourceRequests( + cpus=10, + memory="40Gi", + gpus=1, + gpu_type=GpuType.NVIDIA_HOPPER_H100_3G_40GB, + storage="80Gi", + nodes_per_worker=1, + ) + + # assert fake_llm_batch_completions_service was called with the correct arguments + fake_llm_batch_completions_service.create_batch_job.assert_called_with( + user=user, + job_request=expected_engine_request, + image_repo="llm-engine/batch-infer-vllm", + image_tag="fake_docker_repository_latest_image_tag", + resource_requests=expected_hardware, + labels=create_batch_completions_v2_request.labels, + max_runtime_sec=create_batch_completions_v2_request.max_runtime_sec, + num_workers=create_batch_completions_v2_request.data_parallelism, + ) + + await use_case.execute(create_batch_completions_v2_request_with_hardware, user) + + expected_engine_request = CreateBatchCompletionsEngineRequest( + model_cfg=create_batch_completions_v2_request_with_hardware.model_cfg, + max_runtime_sec=create_batch_completions_v2_request_with_hardware.max_runtime_sec, + data_parallelism=create_batch_completions_v2_request_with_hardware.data_parallelism, + labels=create_batch_completions_v2_request_with_hardware.labels, + content=create_batch_completions_v2_request_with_hardware.content, + output_data_path=create_batch_completions_v2_request_with_hardware.output_data_path, + ) + + expected_hardware = CreateDockerImageBatchJobResourceRequests( + cpus=create_batch_completions_v2_request_with_hardware.cpus, + gpus=create_batch_completions_v2_request_with_hardware.gpus, + memory=create_batch_completions_v2_request_with_hardware.memory, + storage=create_batch_completions_v2_request_with_hardware.storage, + gpu_type=create_batch_completions_v2_request_with_hardware.gpu_type, + nodes_per_worker=create_batch_completions_v2_request_with_hardware.nodes_per_worker, + ) + # assert fake_llm_batch_completions_service was called with the correct arguments + fake_llm_batch_completions_service.create_batch_job.assert_called_with( + user=user, + job_request=expected_engine_request, + image_repo="llm-engine/batch-infer-vllm", + image_tag="fake_docker_repository_latest_image_tag", + resource_requests=expected_hardware, + labels=create_batch_completions_v2_request.labels, + max_runtime_sec=create_batch_completions_v2_request.max_runtime_sec, + num_workers=create_batch_completions_v2_request.data_parallelism, + ) + + +def test_merge_metadata(): + request_metadata = { + "key1": "value1", + "key2": "value2", + } + + endpoint_metadata = { + "key1": "value0", + "key3": "value3", + } + + assert merge_metadata(request_metadata, None) == request_metadata + assert merge_metadata(None, endpoint_metadata) == endpoint_metadata + assert merge_metadata(request_metadata, endpoint_metadata) == { + "key1": "value1", + "key2": "value2", + "key3": "value3", + } + + +def test_validate_chat_template(): + assert validate_chat_template(None, LLMInferenceFramework.DEEPSPEED) is None + good_chat_template = CHAT_TEMPLATE_MAX_LENGTH * "_" + assert validate_chat_template(good_chat_template, LLMInferenceFramework.VLLM) is None + + bad_chat_template = (CHAT_TEMPLATE_MAX_LENGTH + 1) * "_" + with pytest.raises(ObjectHasInvalidValueException): + validate_chat_template(bad_chat_template, LLMInferenceFramework.DEEPSPEED) + + with pytest.raises(ObjectHasInvalidValueException): + validate_chat_template(good_chat_template, LLMInferenceFramework.DEEPSPEED) diff --git a/server/tests/unit/domain/test_model_bundle_use_cases.py b/model-engine/tests/unit/domain/test_model_bundle_use_cases.py similarity index 97% rename from server/tests/unit/domain/test_model_bundle_use_cases.py rename to model-engine/tests/unit/domain/test_model_bundle_use_cases.py index 820ffb93..ae2bb7e2 100644 --- a/server/tests/unit/domain/test_model_bundle_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_bundle_use_cases.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.common.dtos.model_bundles import ( +from model_engine_server.common.dtos.model_bundles import ( CloneModelBundleV1Request, CreateModelBundleV1Request, CreateModelBundleV1Response, @@ -9,15 +9,15 @@ ModelBundleOrderBy, ModelBundleV1Response, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.exceptions import ( DockerImageNotFoundException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.gateways import ModelPrimitiveGateway -from llm_engine_server.domain.repositories import DockerRepository, ModelBundleRepository -from llm_engine_server.domain.use_cases.model_bundle_use_cases import ( +from model_engine_server.domain.gateways import ModelPrimitiveGateway +from model_engine_server.domain.repositories import DockerRepository, ModelBundleRepository +from model_engine_server.domain.use_cases.model_bundle_use_cases import ( CloneModelBundleV1UseCase, CreateModelBundleV1UseCase, CreateModelBundleV2UseCase, @@ -446,7 +446,7 @@ async def test_create_model_bundle_v2_full_url_use_case_success( model_primitive_gateway=fake_model_primitive_gateway, ) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - # will a full uri specification, image existence is not checked + # will a full uri specification, image existance is not checked create_model_bundle_v2_request.flavor.repository = "registry.hub.docker.com/library/busybox" response = await use_case.execute(user=user, request=create_model_bundle_v2_request) assert response.model_bundle_id diff --git a/server/tests/unit/domain/test_model_endpoint_use_cases.py b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py similarity index 54% rename from server/tests/unit/domain/test_model_endpoint_use_cases.py rename to model-engine/tests/unit/domain/test_model_endpoint_use_cases.py index 20a64885..8fd5cf19 100644 --- a/server/tests/unit/domain/test_model_endpoint_use_cases.py +++ b/model-engine/tests/unit/domain/test_model_endpoint_use_cases.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.common.dtos.model_endpoints import ( +from model_engine_server.common.dtos.model_endpoints import ( CreateModelEndpointV1Request, CreateModelEndpointV1Response, DeleteModelEndpointV1Response, @@ -9,29 +9,33 @@ UpdateModelEndpointV1Request, UpdateModelEndpointV1Response, ) -from llm_engine_server.common.resource_limits import ( +from model_engine_server.common.resource_limits import ( FORWARDER_CPU_USAGE, FORWARDER_MEMORY_USAGE, FORWARDER_STORAGE_USAGE, REQUESTS_BY_GPU_TYPE, STORAGE_LIMIT, ) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.domain.exceptions import ( + EndpointBillingTagsMalformedException, + EndpointLabelsException, + EndpointResourceInvalidRequestException, ObjectHasInvalidValueException, ObjectNotAuthorizedException, ObjectNotFoundException, + PostInferenceHooksException, ) -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.domain.use_cases.model_endpoint_use_cases import ( +from model_engine_server.domain.use_cases.model_endpoint_use_cases import ( + CONVERTED_FROM_ARTIFACT_LIKE_KEY, CreateModelEndpointV1UseCase, DeleteModelEndpointByIdV1UseCase, GetModelEndpointByIdV1UseCase, ListModelEndpointsV1UseCase, UpdateModelEndpointByIdV1UseCase, ) -from llm_engine_server.infra.gateways.k8s_resource_parser import parse_mem_request +from model_engine_server.infra.gateways.k8s_resource_parser import parse_mem_request @pytest.mark.asyncio @@ -65,6 +69,38 @@ async def test_create_model_endpoint_use_case_success( assert response_3.endpoint_creation_task_id assert isinstance(response_3, CreateModelEndpointV1Response) + # test special case where sync/streaming endpoint that has 0-1 min-max workers works + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(True) + request = create_model_endpoint_request_sync.copy() + request.min_workers = 0 + request.max_workers = 1 + response_4 = await use_case.execute(user=user, request=request) + assert response_4.endpoint_creation_task_id + assert isinstance(response_4, CreateModelEndpointV1Response) + + request = create_model_endpoint_request_streaming.copy() + request.min_workers = 0 + request.max_workers = 1 + response_5 = await use_case.execute(user=user, request=request) + assert response_5.endpoint_creation_task_id + assert isinstance(response_5, CreateModelEndpointV1Response) + + # test general case as well for 0-N + request = create_model_endpoint_request_sync.copy() + request.min_workers = 0 + request.max_workers = 5 + response_6 = await use_case.execute(user=user, request=request) + assert response_6.endpoint_creation_task_id + assert isinstance(response_6, CreateModelEndpointV1Response) + + # test you can ask for more storage on H100s + request = create_model_endpoint_request_sync.copy() + request.storage = "950Gi" + request.gpu_type = "nvidia-hopper-h100" + response_7 = await use_case.execute(user=user, request=request) + assert response_7.endpoint_creation_task_id + assert isinstance(response_7, CreateModelEndpointV1Response) + @pytest.mark.asyncio async def test_create_model_endpoint_use_case_raises_invalid_value_exception( @@ -165,10 +201,12 @@ async def test_create_model_endpoint_use_case_raises_resource_request_exception( with pytest.raises(EndpointResourceInvalidRequestException): await use_case.execute(user=user, request=request) + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(False) request = create_model_endpoint_request_sync.copy() request.min_workers = 0 with pytest.raises(EndpointResourceInvalidRequestException): await use_case.execute(user=user, request=request) + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(True) request = create_model_endpoint_request_async.copy() request.max_workers = 2**63 @@ -279,6 +317,154 @@ async def test_create_model_endpoint_use_case_raises_resource_request_exception( await use_case.execute(user=user, request=request) +@pytest.mark.asyncio +async def test_create_model_endpoint_use_case_raises_endpoint_labels_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + create_model_endpoint_request_async: CreateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_1.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = create_model_endpoint_request_async.copy() + request.labels = None # type: ignore + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = {} + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = {"team": "infra"} + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = {"product": "my_product"} + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = { + "team": "infra", + "product": "my_product", + "user_id": "test_labels_user", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.labels = { + "team": "infra", + "product": "my_product", + "endpoint_name": "test_labels_endpoint_name", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + +@pytest.mark.skip(reason="TODO: team validation is currently disabled") +@pytest.mark.asyncio +async def test_create_model_endpoint_use_case_invalid_team_raises_endpoint_labels_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + create_model_endpoint_request_async: CreateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_1.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = create_model_endpoint_request_async.copy() + request.labels = { + "team": "unknown_team", + "product": "my_product", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute(user=user, request=request) + + # for team in ALLOWED_TEAMS: + # # Conversely, make sure that all the ALLOWED_TEAMS are, well, allowed. + # request = create_model_endpoint_request_async.copy() + # request.labels = { + # "team": team, + # "product": "my_product", + # } + # await use_case.execute(user=user, request=request) + + +@pytest.mark.asyncio +async def test_create_model_endpoint_use_case_raises_billing_tags_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + create_model_endpoint_request_async: CreateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_1.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = None + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = { + "idempotencyKeyPrefix": "val1", + "product": "val2", + "type": "val3", + "subType": "val4", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "val5", + "payor": "val6", + "reference": {"referenceType": "val7", "referenceId": "val8"}, + } + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = {"incomplete_labels": "hi"} + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = { + "idempotencyKeyPrefix": ["wrong", "type"], + "product": "val2", + "type": "val3", + "subType": "val4", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "val5", + "payor": "val6", + "reference": {"referenceType": "val7", "referenceId": "val8"}, + } + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute(user=user, request=request) + + request = create_model_endpoint_request_async.copy() + request.billing_tags = "not_a_dict" # type: ignore + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute(user=user, request=request) + + @pytest.mark.asyncio async def test_create_model_endpoint_use_case_validates_post_inference_hooks( fake_model_bundle_repository, @@ -297,7 +483,7 @@ async def test_create_model_endpoint_use_case_validates_post_inference_hooks( request = create_model_endpoint_request_async.copy() request.post_inference_hooks = ["invalid_hook"] - with pytest.raises(ValueError): + with pytest.raises(PostInferenceHooksException): await use_case.execute(user=user, request=request) @@ -451,19 +637,77 @@ async def test_create_model_endpoint_use_case_sets_high_priority( await fake_model_endpoint_service.delete_model_endpoint(endpoints[0].record.id) +@pytest.mark.asyncio +async def test_create_multinode_endpoint_with_nonmultinode_bundle_fails( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + create_model_endpoint_request_streaming: CreateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_1.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + create_model_endpoint_request_streaming.nodes_per_worker = 2 + create_model_endpoint_request_streaming.model_bundle_id = model_bundle_1.id + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute(user=user, request=create_model_endpoint_request_streaming) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("nodes_per_worker", [1, 2]) +async def test_create_multinode_or_nonmultinode_endpoint_with_multinode_bundle_succeeds( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_5: ModelBundle, + create_model_endpoint_request_streaming: CreateModelEndpointV1Request, + nodes_per_worker: int, +): + # mb5 is a streaming runnable image bundle + model_bundle_5.flavor.worker_env = {"fake_env": "fake_value"} + model_bundle_5.flavor.worker_command = ["fake_command"] + fake_model_bundle_repository.add_model_bundle(model_bundle_5) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = CreateModelEndpointV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_bundle_5.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + create_model_endpoint_request_streaming.nodes_per_worker = nodes_per_worker + create_model_endpoint_request_streaming.model_bundle_id = model_bundle_5.id + response = await use_case.execute(user=user, request=create_model_endpoint_request_streaming) + assert response.endpoint_creation_task_id + assert isinstance(response, CreateModelEndpointV1Response) + + @pytest.mark.asyncio async def test_get_model_endpoint_use_case_success( test_api_key: str, fake_model_endpoint_service, model_endpoint_1: ModelEndpoint, + model_endpoint_2: ModelEndpoint, ): + # Tests single node + multinode fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + model_endpoint_2.infra_state.resource_state.nodes_per_worker = 2 + fake_model_endpoint_service.add_model_endpoint(model_endpoint_2) use_case = GetModelEndpointByIdV1UseCase(model_endpoint_service=fake_model_endpoint_service) user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) response = await use_case.execute(user=user, model_endpoint_id=model_endpoint_1.record.id) assert isinstance(response, GetModelEndpointV1Response) + response_2 = await use_case.execute(user=user, model_endpoint_id=model_endpoint_2.record.id) + assert isinstance(response_2, GetModelEndpointV1Response) + assert response_2.resource_state.nodes_per_worker == 2 + @pytest.mark.asyncio async def test_get_model_endpoint_use_case_same_team_finds_endpoint( @@ -690,6 +934,305 @@ async def test_update_model_endpoint_team_success( assert isinstance(response, UpdateModelEndpointV1Response) +@pytest.mark.asyncio +async def test_update_model_endpoint_use_case_raises_invalid_value_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_2: ModelBundle, + model_endpoint_1: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = update_model_endpoint_request.copy() + request.metadata = {CONVERTED_FROM_ARTIFACT_LIKE_KEY: False} + with pytest.raises(ObjectHasInvalidValueException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + +@pytest.mark.asyncio +async def test_update_model_endpoint_use_case_raises_resource_request_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + model_bundle_2: ModelBundle, + model_bundle_4: ModelBundle, + model_bundle_6: ModelBundle, + model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage: ModelBundle, + model_endpoint_1: ModelEndpoint, + model_endpoint_2: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_bundle_repository.add_model_bundle(model_bundle_4) + fake_model_bundle_repository.add_model_bundle(model_bundle_6) + fake_model_bundle_repository.add_model_bundle( + model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage + ) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_2) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + + request = update_model_endpoint_request.copy() + request.cpus = -1 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.cpus = float("inf") + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.memory = "invalid_memory_amount" + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.memory = float("inf") + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.storage = "invalid_storage_amount" + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.storage = float("inf") + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + # specific to sync endpoint + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(False) + request = update_model_endpoint_request.copy() + request.min_workers = 0 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_2.record.id, + request=request, + ) + fake_model_endpoint_service.set_can_scale_http_endpoint_from_zero_flag(True) + + request = update_model_endpoint_request.copy() + request.max_workers = 2**63 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.gpus = 0 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.gpu_type = None + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.gpu_type = "invalid_gpu_type" + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + instance_limits = REQUESTS_BY_GPU_TYPE[model_endpoint_1.infra_state.resource_state.gpu_type] + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_1.id + # Test that request.cpus + FORWARDER_CPU_USAGE > instance_limits["cpus"] should fail + request.cpus = instance_limits["cpus"] + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_1.id + # Test that request.memory + FORWARDER_MEMORY_USAGE > instance_limits["memory"] should fail + request.memory = instance_limits["memory"] + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_1.id + # Test that request.storage + FORWARDER_STORAGE_USAGE > STORAGE_LIMIT should fail + request.storage = STORAGE_LIMIT + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_4.id + # Test that request.cpus + FORWARDER_CPU_USAGE > instance_limits["cpus"] should fail + request.cpus = instance_limits["cpus"] + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_4.id + # Test that request.memory + FORWARDER_MEMORY_USAGE > instance_limits["memory"] should fail + request.memory = instance_limits["memory"] + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_4.id + # Test that request.storage + FORWARDER_STORAGE_USAGE > STORAGE_LIMIT should fail + request.storage = STORAGE_LIMIT + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + # Test TritonEnhancedRunnableImageFlavor specific validation logic + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # TritonEnhancedRunnableImageFlavor requires gpu >= 1 + request.gpus = 0.9 + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # TritonEnhancedRunnableImageFlavor requires gpu_type be specified + request.gpu_type = None + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # Test that request.cpus + FORWARDER_CPU_USAGE + triton_num_cpu > instance_limits["cpu"] should fail + request.cpus = instance_limits["cpus"] - FORWARDER_CPU_USAGE + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # Test that request.memory + FORWARDER_MEMORY_USAGE + triton_memory > instance_limits["memory"] should fail + request.memory = parse_mem_request(instance_limits["memory"]) - parse_mem_request( + FORWARDER_MEMORY_USAGE + ) + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.model_bundle_id = model_bundle_6.id + # Test that request.storage + FORWARDER_STORAGE_USAGE + triton_storage > STORAGE_LIMIT should fail + request.storage = parse_mem_request(STORAGE_LIMIT) - parse_mem_request(FORWARDER_STORAGE_USAGE) + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + # Test triton_num_cpu >= 1 + request.model_bundle_id = ( + model_bundle_triton_enhanced_runnable_image_0_cpu_None_memory_storage.id + ) + with pytest.raises(EndpointResourceInvalidRequestException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + @pytest.mark.asyncio async def test_update_model_endpoint_raises_not_found( fake_model_bundle_repository, @@ -768,6 +1311,201 @@ async def test_update_model_endpoint_raises_not_authorized( ) +@pytest.mark.asyncio +async def test_update_model_endpoint_raises_endpoint_labels_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + model_bundle_2: ModelBundle, + model_endpoint_1: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + + request = update_model_endpoint_request.copy() + request.labels = {"team": "infra"} + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.labels = {"product": "my_product"} + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.labels = { + "team": "infra", + "product": "my_product", + "user_id": "test_labels_user", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.labels = { + "team": "infra", + "product": "my_product", + "endpoint_name": "test_labels_endpoint_name", + } + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + +@pytest.mark.skip(reason="TODO: team validation is currently disabled") +@pytest.mark.asyncio +async def test_update_model_endpoint_invalid_team_raises_endpoint_labels_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + model_bundle_2: ModelBundle, + model_endpoint_1: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + + request = update_model_endpoint_request.copy() + request.labels = { + "team": "invalid_team", + "product": "some_product", + } + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + with pytest.raises(EndpointLabelsException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + # TODO: renable this part of the test once we figure out how to import this + # properly + + # for team in ALLOWED_TEAMS: + # # Conversely, make sure that all the ALLOWED_TEAMS are, well, allowed. + # request = update_model_endpoint_request.copy() + # request.labels = { + # "team": team, + # "product": "my_product", + # } + # await use_case.execute( + # user=user, model_endpoint_id=model_endpoint_1.record.id, request=request + # ) + + +@pytest.mark.asyncio +async def test_update_model_endpoint_raises_billing_tags_exception( + fake_model_bundle_repository, + fake_model_endpoint_service, + model_bundle_1: ModelBundle, + model_bundle_2: ModelBundle, + model_endpoint_1: ModelEndpoint, + update_model_endpoint_request: UpdateModelEndpointV1Request, +): + fake_model_bundle_repository.add_model_bundle(model_bundle_1) + fake_model_bundle_repository.add_model_bundle(model_bundle_2) + fake_model_endpoint_service.add_model_endpoint(model_endpoint_1) + fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository + use_case = UpdateModelEndpointByIdV1UseCase( + model_bundle_repository=fake_model_bundle_repository, + model_endpoint_service=fake_model_endpoint_service, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = None + user_id = model_endpoint_1.record.created_by + user = User(user_id=user_id, team_id=user_id, is_privileged_user=True) + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = { + "idempotencyKeyPrefix": "val1", + "product": "val2", + "type": "val3", + "subType": "val4", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "val5", + "payor": "val6", + "reference": {"referenceType": "val7", "referenceId": "val8"}, + } + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = {"incomplete_labels": "hi"} + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = { + "idempotencyKeyPrefix": ["wrong", "type"], + "product": "val2", + "type": "val3", + "subType": "val4", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "val5", + "payor": "val6", + "reference": {"referenceType": "val7", "referenceId": "val8"}, + } + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + request = update_model_endpoint_request.copy() + request.billing_tags = "not_a_dict" # type: ignore + with pytest.raises(EndpointBillingTagsMalformedException): + await use_case.execute( + user=user, + model_endpoint_id=model_endpoint_1.record.id, + request=request, + ) + + @pytest.mark.asyncio async def test_delete_model_endpoint_success( fake_model_endpoint_service, diff --git a/server/tests/unit/domain/test_streaming_inference_use_cases.py b/model-engine/tests/unit/domain/test_streaming_inference_use_cases.py similarity index 88% rename from server/tests/unit/domain/test_streaming_inference_use_cases.py rename to model-engine/tests/unit/domain/test_streaming_inference_use_cases.py index 5043b6a4..9da48267 100644 --- a/server/tests/unit/domain/test_streaming_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_streaming_inference_use_cases.py @@ -1,15 +1,15 @@ from typing import Any, Dict, Tuple import pytest -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.exceptions import ( + EndpointUnsupportedInferenceTypeException, ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.use_cases.streaming_inference_use_cases import ( +from model_engine_server.domain.use_cases.streaming_inference_use_cases import ( CreateStreamingInferenceTaskV1UseCase, ) diff --git a/server/tests/unit/domain/test_sync_inference_use_cases.py b/model-engine/tests/unit/domain/test_sync_inference_use_cases.py similarity index 90% rename from server/tests/unit/domain/test_sync_inference_use_cases.py rename to model-engine/tests/unit/domain/test_sync_inference_use_cases.py index ffb3637e..673cafa1 100644 --- a/server/tests/unit/domain/test_sync_inference_use_cases.py +++ b/model-engine/tests/unit/domain/test_sync_inference_use_cases.py @@ -1,14 +1,14 @@ from typing import Any, Dict, Tuple import pytest -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.domain.exceptions import ( ObjectNotAuthorizedException, ObjectNotFoundException, ) -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.domain.use_cases.sync_inference_use_cases import ( +from model_engine_server.domain.use_cases.sync_inference_use_cases import ( CreateSyncInferenceTaskV1UseCase, ) diff --git a/model-engine/tests/unit/inference/conftest.py b/model-engine/tests/unit/inference/conftest.py new file mode 100644 index 00000000..e07a7c73 --- /dev/null +++ b/model-engine/tests/unit/inference/conftest.py @@ -0,0 +1,213 @@ +from unittest.mock import MagicMock + +import pytest +from model_engine_server.inference.batch_inference.dto import ( + CompletionOutput, + CreateBatchCompletionsEngineRequest, + CreateBatchCompletionsModelConfig, + CreateBatchCompletionsRequestContent, + TokenOutput, + ToolConfig, +) + + +@pytest.fixture +def anyio_backend(): + return "asyncio" + + +@pytest.fixture +def create_batch_completions_engine_request() -> CreateBatchCompletionsEngineRequest: + model_config = CreateBatchCompletionsModelConfig( + model="model", + checkpoint_path="checkpoint_path", + labels={}, + seed=123, + num_shards=4, + ) + return CreateBatchCompletionsEngineRequest( + input_data_path="input_data_path", + output_data_path="output_data_path", + model_cfg=model_config, + model_config=model_config, + data_parallelism=1, + max_runtime_sec=86400, + max_gpu_memory_utilization=0.95, + ) + + +@pytest.fixture +def create_batch_completions_tool_completion_request(): + model_config = CreateBatchCompletionsModelConfig( + checkpoint_path="checkpoint_path", + model="model", + num_shards=4, + seed=123, + labels={}, + ) + + return CreateBatchCompletionsEngineRequest( + model_cfg=model_config, + model_config=model_config, + data_parallelism=1, + input_data_path="input_data_path", + output_data_path="output_data_path", + tool_config=ToolConfig(name="code_evaluator"), + ) + + +@pytest.fixture +def create_batch_completions_tool_completion_request_content(): + return CreateBatchCompletionsRequestContent( + prompts=["prompt1"], + max_new_tokens=100, + temperature=0.8, + return_token_log_probs=True, + ) + + +@pytest.fixture +def create_batch_completions_request_content(): + return CreateBatchCompletionsRequestContent( + prompts=["prompt1", "prompt2"], + max_new_tokens=100, + temperature=0.8, + return_token_log_probs=True, + ) + + +@pytest.fixture +def create_vllm_request_outputs(): + class Logprob: + """mock, from https://github.com/vllm-project/vllm/blob/v0.4.1/vllm/sequence.py#L18""" + + def __init__(self, logprob: float): + self.logprob = logprob + + mock_vllm_request_output1 = MagicMock() + mock_vllm_request_output1.outputs = [ + MagicMock(text="text1"), + ] + mock_vllm_request_output1.prompt_token_ids = [1, 2, 3] + mock_vllm_request_output1.outputs[0].token_ids = [4] + mock_vllm_request_output1.outputs[0].logprobs = [{4: Logprob(0.1)}] + + mock_vllm_request_output2 = MagicMock() + mock_vllm_request_output2.outputs = [ + MagicMock(text="text1 text2"), + ] + mock_vllm_request_output2.prompt_token_ids = [1, 2, 3] + mock_vllm_request_output2.outputs[0].token_ids = [4, 5] + mock_vllm_request_output2.outputs[0].logprobs = [{4: Logprob(0.1), 5: Logprob(0.2)}] + + mock_vllm_request_output3 = MagicMock() + mock_vllm_request_output3.outputs = [ + MagicMock(text="text1 text2 text3"), + ] + mock_vllm_request_output3.prompt_token_ids = [1, 2, 3] + mock_vllm_request_output3.outputs[0].token_ids = [4, 5, 6] + mock_vllm_request_output3.outputs[0].logprobs = [ + {4: Logprob(0.1), 5: Logprob(0.2), 6: Logprob(0.3)} + ] + return [ + mock_vllm_request_output1, + mock_vllm_request_output2, + mock_vllm_request_output3, + ] + + +@pytest.fixture +def mock_s3_client(): + mock_s3_client = MagicMock() + mock_s3_client.delete_object.return_value = None + return mock_s3_client + + +@pytest.fixture +def mock_process(): + mock_process = MagicMock() + mock_process.stdout = [] + mock_process.stderr.readline.side_effect = [ + "error", + ] + mock_process.returncode = 0 + mock_process.wait.return_value = None + return mock_process + + +@pytest.fixture +def mock_completion_output(): + return CompletionOutput( + text="text1 text2 text3", + num_prompt_tokens=3, + num_completion_tokens=3, + tokens=[ + TokenOutput(token="text1", log_prob=0.1), + TokenOutput(token=" text2", log_prob=0.2), + TokenOutput(token=" text3", log_prob=0.3), + ], + ) + + +@pytest.fixture +def mock_tool_completion_output(): + return CompletionOutput( + text="```python\nimport math\nprint(math.sqrt(2))\n```\n1.414...\n", + num_prompt_tokens=10, + num_completion_tokens=28, + tokens=[ + TokenOutput(token="``", log_prob=-0.1980377733707428), + TokenOutput(token="`", log_prob=-0.0037908137310296297), + TokenOutput(token="python", log_prob=-0.015637163072824478), + TokenOutput(token="\n", log_prob=-0.0010788579238578677), + TokenOutput(token="import", log_prob=-0.04351021721959114), + TokenOutput(token=" math", log_prob=-0.0021214615553617477), + TokenOutput(token="\n", log_prob=-0.002169043058529496), + TokenOutput(token="print", log_prob=-0.06555093079805374), + TokenOutput(token="(", log_prob=-0.005272886715829372), + TokenOutput(token="math", log_prob=-0.009995171800255775), + TokenOutput(token=".", log_prob=-0.0002040654799202457), + TokenOutput(token="sqrt", log_prob=-0.00886327400803566), + TokenOutput(token="(", log_prob=-0.0015410225605592132), + TokenOutput(token="2", log_prob=-0.008573509752750397), + TokenOutput(token="))", log_prob=-0.010970987379550934), + TokenOutput(token="\n", log_prob=-0.002175347413867712), + TokenOutput(token="``", log_prob=-0.01911235973238945), + TokenOutput(token="`", log_prob=-0.0005327236140146852), + TokenOutput(token="\n", log_prob=-0.002304519060999155), + TokenOutput(token="1", log_prob=-0.10852570831775665), + TokenOutput(token=".", log_prob=-0.007146273739635944), + TokenOutput(token="4", log_prob=-0.003810290014371276), + TokenOutput(token="1", log_prob=-0.002774677239358425), + TokenOutput(token="4", log_prob=-0.16946221888065338), + TokenOutput(token=".", log_prob=-0.007678280584514141), + TokenOutput(token=".", log_prob=-0.021146666258573532), + TokenOutput(token=".", log_prob=-0.3870151937007904), + TokenOutput(token="\n", log_prob=-0.027081478387117386), + ], + ) + + +@pytest.fixture +def mock_tool_completion_output2(): + return CompletionOutput( + text="Final Answer: 4\n", + num_prompt_tokens=38, + num_completion_tokens=6, + tokens=[ + TokenOutput(token="Final", log_prob=-0.1980377733707428), + TokenOutput(token=" Answer", log_prob=-0.0037908137310296297), + TokenOutput(token=":", log_prob=-0.015637163072824478), + TokenOutput(token=" ", log_prob=-0.0010788579238578677), + TokenOutput(token="4", log_prob=-0.04351021721959114), + TokenOutput(token="\n", log_prob=-0.0021214615553617477), + ], + ) + + +@pytest.fixture +def mock_run_output(): + value = MagicMock() + value.stdout = "1.4142135623730951" + value.check_returncode = MagicMock() + return value diff --git a/model-engine/tests/unit/inference/test_forwarding.py b/model-engine/tests/unit/inference/test_forwarding.py new file mode 100644 index 00000000..0462b317 --- /dev/null +++ b/model-engine/tests/unit/inference/test_forwarding.py @@ -0,0 +1,451 @@ +import json +from dataclasses import dataclass +from typing import Mapping +from unittest import mock + +import pytest +from fastapi import HTTPException +from fastapi.responses import JSONResponse +from model_engine_server.core.utils.env import environment +from model_engine_server.domain.entities import ModelEndpointConfig +from model_engine_server.inference.forwarding.forwarding import ( + ENV_SERIALIZE_RESULTS_AS_STRING, + KEY_SERIALIZE_RESULTS_AS_STRING, + Forwarder, + LoadForwarder, + LoadStreamingForwarder, + StreamingForwarder, + load_named_config, +) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from tests.unit.conftest import FakeStreamingStorageGateway + +PAYLOAD: Mapping[str, str] = {"hello": "world"} +PAYLOAD_END = "[DONE]" + + +def mocked_get(*args, **kwargs): # noqa + @dataclass + class mocked_static_status_code: + status_code: int = 200 + + return mocked_static_status_code() + + +def mocked_post(*args, **kwargs): # noqa + @dataclass + class mocked_static_json: + status_code: int = 200 + + def json(self) -> dict: + return PAYLOAD # type: ignore + + return mocked_static_json() + + +def mocked_post_400(*args, **kwargs): # noqa + @dataclass + class mocked_static_json: + status_code: int = 400 + + def json(self) -> dict: + return PAYLOAD # type: ignore + + return mocked_static_json() + + +def mocked_post_500(*args, **kwargs): # noqa + @dataclass + class mocked_static_json: + status_code: int = 500 + + def json(self) -> dict: + return PAYLOAD # type: ignore + + return mocked_static_json() + + +def mocked_sse_client(*args, **kwargs): # noqa + @dataclass + class Event: + data: str + + @dataclass + class mocked_static_events: + def events(self) -> list: + payload_json = json.dumps(PAYLOAD) + return [ + Event(data=payload_json), + Event(data=payload_json), + Event(data=PAYLOAD_END), + ] + + return mocked_static_events() + + +def mocked_get_endpoint_config(): + return ModelEndpointConfig( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + ) + + +@pytest.fixture +def post_inference_hooks_handler(): + handler = PostInferenceHooksHandler( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + post_inference_hooks=[], + user_id="test_user_id", + billing_queue="billing_queue", + billing_tags=[], + default_callback_url=None, + default_callback_auth=None, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id="test_endpoint_id", + endpoint_type="sync", + bundle_id="test_bundle_id", + labels={}, + streaming_storage_gateway=FakeStreamingStorageGateway(), + ) + return handler + + +def mocked_config_content(): + return { + "forwarder": { + "sync": { + "user_port": 5005, + "user_hostname": "localhost", + "use_grpc": False, + "predict_route": "/predict", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": True, + "forward_http_status": True, + }, + "stream": { + "user_port": 5005, + "user_hostname": "localhost", + "predict_route": "/stream", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": False, + }, + "max_concurrency": 42, + } + } + + +def mocked_config_overrides(): + return [ + "forwarder.sync.extra_routes=['/v1/chat/completions']", + "forwarder.stream.extra_routes=['/v1/chat/completions']", + "forwarder.sync.healthcheck_route=/health", + "forwarder.stream.healthcheck_route=/health", + ] + + +# patch open(config_uri, "rt") and have output be mocked_config_content +@mock.patch("builtins.open", mock.mock_open(read_data=json.dumps(mocked_config_content()))) +def test_load_named_config(): + output = load_named_config("dummy.yml", config_overrides=mocked_config_overrides()) + expected_output = { + "name": "forwarder", + "sync": { + "user_port": 5005, + "user_hostname": "localhost", + "use_grpc": False, + "predict_route": "/predict", + "healthcheck_route": "/health", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": True, + "forward_http_status": True, + "extra_routes": ["/v1/chat/completions"], + }, + "stream": { + "user_port": 5005, + "user_hostname": "localhost", + "predict_route": "/stream", + "healthcheck_route": "/health", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": False, + "extra_routes": ["/v1/chat/completions"], + }, + "max_concurrency": 42, + } + assert output == expected_output + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +def test_forwarders(post_inference_hooks_handler): + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) + json_response = fwd({"ignore": "me"}) + _check(json_response) + + +def _check(json_response) -> None: + json_response = ( + json.loads(json_response.body.decode("utf-8")) + if isinstance(json_response, JSONResponse) + else json_response + ) + assert json_response == {"result": PAYLOAD} + + +def _check_responses_not_wrapped(json_response) -> None: + json_response = ( + json.loads(json_response.body.decode("utf-8")) + if isinstance(json_response, JSONResponse) + else json_response + ) + assert json_response == PAYLOAD + + +def _check_streaming(streaming_response) -> None: + streaming_response_list = list(streaming_response) + assert len(streaming_response_list) == 3 + assert streaming_response_list[0] == {"result": PAYLOAD} + assert streaming_response_list[1] == {"result": PAYLOAD} + assert streaming_response_list[2] == {"result": PAYLOAD_END} + + +def _check_streaming_serialized(streaming_response) -> None: + streaming_response_list = list(streaming_response) + assert len(streaming_response_list) == 3 + assert streaming_response_list[0] == {"result": json.dumps(PAYLOAD)} + assert streaming_response_list[1] == {"result": json.dumps(PAYLOAD)} + assert streaming_response_list[2] == {"result": PAYLOAD_END} + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +def test_forwarders_serialize_results_as_string(post_inference_hooks_handler): + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=True, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) + json_response = fwd({"ignore": "me"}) + _check_serialized(json_response) + + +def _check_serialized(json_response) -> None: + json_response = ( + json.loads(json_response.body.decode("utf-8")) + if isinstance(json_response, JSONResponse) + else json_response + ) + assert isinstance(json_response["result"], str) + assert len(json_response) == 1, f"expecting only 'result' key, but got {json_response=}" + assert json.loads(json_response["result"]) == PAYLOAD + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +def test_forwarders_override_serialize_results(post_inference_hooks_handler): + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=True, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) + json_response = fwd({"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: False}) + _check(json_response) + + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) + json_response = fwd({"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: True}) + _check_serialized(json_response) + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +def test_forwarder_does_not_wrap_response(post_inference_hooks_handler): + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=False, + forward_http_status=True, + ) + json_response = fwd({"ignore": "me"}) + _check_responses_not_wrapped(json_response) + + +@mock.patch("requests.post", mocked_post_500) +@mock.patch("requests.get", mocked_get) +def test_forwarder_return_status_code(post_inference_hooks_handler): + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=True, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=False, + forward_http_status=True, + ) + json_response = fwd({"ignore": "me"}) + _check_responses_not_wrapped(json_response) + assert json_response.status_code == 500 + + +@mock.patch("requests.post", mocked_post_500) +@mock.patch("requests.get", mocked_get) +def test_forwarder_dont_return_status_code(post_inference_hooks_handler): + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=True, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=False, + forward_http_status=False, + ) + json_response = fwd({"ignore": "me"}) + assert json_response == PAYLOAD + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +@mock.patch( + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", + mocked_get_endpoint_config, +) +def test_forwarder_loader(): + fwd = LoadForwarder(serialize_results_as_string=True).load(None, None) # type: ignore + json_response = fwd({"ignore": "me"}) + _check_serialized(json_response) + + fwd = LoadForwarder(serialize_results_as_string=False).load(None, None) # type: ignore + json_response = fwd({"ignore": "me"}) + _check(json_response) + + fwd = LoadForwarder(wrap_response=False).load(None, None) # type: ignore + json_response = fwd({"ignore": "me"}) + _check_responses_not_wrapped(json_response) + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +@mock.patch( + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", + mocked_get_endpoint_config, +) +def test_forwarder_loader_env_serialize_behavior(post_inference_hooks_handler): + with environment(**{ENV_SERIALIZE_RESULTS_AS_STRING: "false"}): + fwd = LoadForwarder(serialize_results_as_string=True).load(None, None) # type: ignore + json_response = fwd({"ignore": "me"}) + _check(json_response) + + with environment(**{ENV_SERIALIZE_RESULTS_AS_STRING: "true"}): + fwd = LoadForwarder(serialize_results_as_string=False).load(None, None) # type: ignore + json_response = fwd({"ignore": "me"}) + _check_serialized(json_response) + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +def test_forwarder_serialize_within_args(post_inference_hooks_handler): + # standard Launch-created forwarder + fwd = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=True, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) + # expected: no `serialize_results_as_string` at top-level nor in 'args' + json_response = fwd({"something": "to ignore", "args": {"my": "payload", "is": "here"}}) + _check_serialized(json_response) + # unwraps under "args" to find `serialize_results_as_string` + payload = { + "something": "to ignore", + "args": {"my": "payload", "is": "here", "serialize_results_as_string": False}, + } + json_response = fwd(payload) + _check(json_response) + # w/o unwrapping it won't "find" the `"serialize_results_as_string": False` directive + fwd = Forwarder( + "ignored", + model_engine_unwrap=False, + serialize_results_as_string=True, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) + json_response = fwd(payload) + _check_serialized(json_response) + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +@mock.patch("sseclient.SSEClient", mocked_sse_client) +def test_streaming_forwarders(post_inference_hooks_handler): + fwd = StreamingForwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + ) + response = fwd({"ignore": "me"}) + _check_streaming(response) + + +@mock.patch("requests.post", mocked_post_400) +@mock.patch("requests.get", mocked_get) +@mock.patch("sseclient.SSEClient", mocked_sse_client) +def test_streaming_forwarder_400_upstream(post_inference_hooks_handler): + fwd = StreamingForwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + ) + with pytest.raises(HTTPException) as e: + fwd({"ignore": "me"}) + + assert e.value.status_code == 400 + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +@mock.patch("sseclient.SSEClient", mocked_sse_client) +@mock.patch( + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", + mocked_get_endpoint_config, +) +def test_streaming_forwarder_loader(): + fwd = LoadStreamingForwarder(serialize_results_as_string=True).load(None, None) # type: ignore + json_response = fwd({"ignore": "me"}) + _check_streaming_serialized(json_response) + + fwd = LoadStreamingForwarder(serialize_results_as_string=False).load(None, None) # type: ignore + response = fwd({"ignore": "me"}) + _check_streaming(response) diff --git a/model-engine/tests/unit/inference/test_http_forwarder.py b/model-engine/tests/unit/inference/test_http_forwarder.py new file mode 100644 index 00000000..0edfb444 --- /dev/null +++ b/model-engine/tests/unit/inference/test_http_forwarder.py @@ -0,0 +1,377 @@ +import json +import threading +from dataclasses import dataclass +from typing import Mapping +from unittest import mock + +import pytest +import requests_mock +from aioresponses import aioresponses +from fastapi import BackgroundTasks, FastAPI +from fastapi.responses import JSONResponse +from fastapi.testclient import TestClient +from model_engine_server.common.dtos.tasks import EndpointPredictV1Request +from model_engine_server.domain.entities.model_endpoint_entity import ModelEndpointConfig +from model_engine_server.inference.forwarding.forwarding import Forwarder +from model_engine_server.inference.forwarding.http_forwarder import ( + MultiprocessingConcurrencyLimiter, + get_concurrency_limiter, + get_forwarder_loader, + get_streaming_forwarder_loader, + init_app, + predict, +) +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) +from model_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler +from tests.unit.conftest import FakeStreamingStorageGateway + +PAYLOAD: Mapping[str, str] = {"hello": "world"} + + +class ExceptionCapturedThread(threading.Thread): + def __init__(self, target, args): + super().__init__(target=target, args=args) + self.ex = None + + def run(self): + try: + self._target(*self._args) + except Exception as e: + self.ex = e + + def join(self): + super().join() + if self.ex is not None: + raise self.ex + + +def mocked_get(*args, **kwargs): # noqa + @dataclass + class mocked_static_status_code: + status_code: int = 200 + + return mocked_static_status_code() + + +def mocked_post(*args, **kwargs): # noqa + @dataclass + class mocked_static_json: + status_code: int = 200 + + def json(self) -> dict: + return PAYLOAD # type: ignore + + return mocked_static_json() + + +def mocked_get_config(): + return { + "sync": { + "user_port": 5005, + "user_hostname": "localhost", + "use_grpc": False, + "predict_route": "/predict", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": True, + "forward_http_status": True, + }, + "stream": { + "user_port": 5005, + "user_hostname": "localhost", + "predict_route": "/stream", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": False, + }, + "max_concurrency": 42, + } + + +@pytest.fixture +def post_inference_hooks_handler(): + handler = PostInferenceHooksHandler( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + post_inference_hooks=[], + user_id="test_user_id", + billing_queue="billing_queue", + billing_tags=[], + default_callback_url=None, + default_callback_auth=None, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id="test_endpoint_id", + endpoint_type="sync", + bundle_id="test_bundle_id", + labels={}, + streaming_storage_gateway=FakeStreamingStorageGateway(), + ) + return handler + + +@pytest.fixture +def post_inference_hooks_handler_with_logging(): + handler = PostInferenceHooksHandler( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + post_inference_hooks=["logging"], + user_id="test_user_id", + billing_queue="billing_queue", + billing_tags=[], + default_callback_url=None, + default_callback_auth=None, + monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), + endpoint_id="test_endpoint_id", + endpoint_type="sync", + bundle_id="test_bundle_id", + labels={}, + streaming_storage_gateway=FakeStreamingStorageGateway(), + ) + return handler + + +@pytest.fixture +def mock_request(): + return EndpointPredictV1Request( + url="test_url", + return_pickled=False, + args={"x": 1}, + ) + + +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config, +) +def test_get_forwarder_loader(): + loader = get_forwarder_loader() + assert loader.predict_route == "/predict" + + loader = get_forwarder_loader("/v1/chat/completions") + assert loader.predict_route == "/v1/chat/completions" + + +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config, +) +def test_get_streaming_forwarder_loader(): + loader = get_streaming_forwarder_loader() + assert loader.predict_route == "/stream" + + loader = get_streaming_forwarder_loader("/v1/chat/completions") + assert loader.predict_route == "/v1/chat/completions" + + +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config, +) +def test_get_concurrency_limiter(): + limiter = get_concurrency_limiter() + assert isinstance(limiter, MultiprocessingConcurrencyLimiter) + assert limiter.concurrency == 42 + + +@mock.patch("requests.post", mocked_post) +@mock.patch("requests.get", mocked_get) +@pytest.mark.skip(reason="This test is flaky") +def test_http_service_429(mock_request, post_inference_hooks_handler): + mock_forwarder = Forwarder( + "ignored", + model_engine_unwrap=True, + serialize_results_as_string=False, + post_inference_hooks_handler=post_inference_hooks_handler, + wrap_response=True, + forward_http_status=True, + ) + limiter = MultiprocessingConcurrencyLimiter(1, True) + t1 = ExceptionCapturedThread( + target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter) + ) + t2 = ExceptionCapturedThread( + target=predict, args=(mock_request, BackgroundTasks(), mock_forwarder, limiter) + ) + t1.start() + t2.start() + t1.join() + with pytest.raises(Exception): # 429 thrown + t2.join() + + +def test_handler_response(post_inference_hooks_handler): + try: + post_inference_hooks_handler.handle( + request_payload=mock_request, response=PAYLOAD, task_id="test_task_id" + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_handler_json_response(post_inference_hooks_handler): + try: + post_inference_hooks_handler.handle( + request_payload=mock_request, + response=JSONResponse(content=PAYLOAD), + task_id="test_task_id", + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +def test_handler_with_logging(post_inference_hooks_handler_with_logging): + try: + post_inference_hooks_handler_with_logging.handle( + request_payload=mock_request, + response=JSONResponse(content=PAYLOAD), + task_id="test_task_id", + ) + except Exception as e: + pytest.fail(f"Unexpected exception: {e}") + + +# Test the fastapi app + + +def mocked_get_config_with_extra_paths(): + return { + "sync": { + "user_port": 5005, + "user_hostname": "localhost", + "use_grpc": False, + "predict_route": "/predict", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": True, + "forward_http_status": True, + "extra_routes": ["/v1/chat/completions"], + }, + "stream": { + "user_port": 5005, + "user_hostname": "localhost", + "predict_route": "/stream", + "healthcheck_route": "/readyz", + "batch_route": None, + "model_engine_unwrap": True, + "serialize_results_as_string": False, + "extra_routes": ["/v1/chat/completions"], + }, + "max_concurrency": 42, + } + + +def get_predict_endpoint(config): + cfg_sync = config["sync"] + predict_endpoint = ( + f"http://{cfg_sync['user_hostname']}:{cfg_sync['user_port']}{cfg_sync['predict_route']}" + ) + return predict_endpoint + + +def get_healthcheck_endpoint(config): + cfg_sync = config["sync"] + healthcheck_endpoint = ( + f"http://{cfg_sync['user_hostname']}:{cfg_sync['user_port']}{cfg_sync['healthcheck_route']}" + ) + return healthcheck_endpoint + + +def get_stream_endpoint(config): + cfg_stream = config["stream"] + stream_endpoint = f"http://{cfg_stream['user_hostname']}:{cfg_stream['user_port']}{cfg_stream['predict_route']}" + return stream_endpoint + + +def get_chat_endpoint(config): + cfg_sync = config["sync"] + chat_endpoint = ( + f"http://{cfg_sync['user_hostname']}:{cfg_sync['user_port']}{cfg_sync['extra_routes'][0]}" + ) + return chat_endpoint + + +def mocked_get_endpoint_config(): + return ModelEndpointConfig( + endpoint_name="test_endpoint_name", + bundle_name="test_bundle_name", + ) + + +@pytest.fixture() +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config_with_extra_paths, +) +@mock.patch( + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", + mocked_get_endpoint_config, +) +async def mocked_app() -> FastAPI: + with requests_mock.Mocker() as req_mock: + healthcheck_endpoint = get_healthcheck_endpoint(mocked_get_config_with_extra_paths()) + req_mock.get( + healthcheck_endpoint, + json={"status": "ok"}, + ) + app = await init_app() + return app + + +def wrap_request(request): + return {"url": "", "args": request} + + +def wrap_result(result): + return {"result": result} + + +@pytest.mark.anyio +@mock.patch( + "model_engine_server.inference.forwarding.http_forwarder.get_config", + mocked_get_config_with_extra_paths, +) +@mock.patch( + "model_engine_server.inference.forwarding.forwarding.get_endpoint_config", + mocked_get_endpoint_config, +) +async def test_mocked_app_success(mocked_app): + config = mocked_get_config_with_extra_paths() + config_sync = config["sync"] + # config_stream = config["stream"] + + predict_endpoint = get_predict_endpoint(config) + healthcheck_endpoint = get_healthcheck_endpoint(config) + + # stream_endpoint = get_stream_endpoint(config) + chat_endpoint = get_chat_endpoint(config) + + raw_payload = {"prompt": "Hello", "stream": False} + raw_result = {"message": "Hello World"} + + payload = wrap_request(raw_payload) + expected_result = wrap_result( + json.dumps(raw_result) if config_sync["serialize_results_as_string"] else raw_result + ) + with TestClient( + mocked_app + ) as client, aioresponses() as aio_mock, requests_mock.Mocker() as req_mock: + req_mock.get( + healthcheck_endpoint, + json={"status": "ok"}, + ) + aio_mock.post(predict_endpoint, status=200, payload=raw_result) + response = client.post("/predict", json=payload) + assert response.status_code == 200 + assert response.json() == expected_result + + aio_mock.post(chat_endpoint, status=200, payload=raw_result) + response = client.post("/v1/chat/completions", json=payload) + assert response.status_code == 200 + assert response.json() == expected_result + + # TODO: add tests for streaming; it's not as trivial as I'd hoped diff --git a/model-engine/tests/unit/inference/test_vllm_batch.py b/model-engine/tests/unit/inference/test_vllm_batch.py new file mode 100644 index 00000000..5462c6ae --- /dev/null +++ b/model-engine/tests/unit/inference/test_vllm_batch.py @@ -0,0 +1,450 @@ +import json +from unittest.mock import call, mock_open, patch + +import pytest +from model_engine_server.inference.batch_inference.vllm_batch import batch_inference, file_exists + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("subprocess.Popen") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") +async def test_batch_inference( + mock_builtins_open_func, + mock_open_func, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_engine_request, + mock_vllm, + create_batch_completions_engine_request, + create_batch_completions_request_content, + mock_s3_client, + mock_process, + mock_completion_output, +): + # Mock the necessary objects and data + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( + create_batch_completions_engine_request + ) + mock_create_batch_completions_request_content.model_validate_json.return_value = ( + create_batch_completions_request_content + ) + + # Mock the generate_with_vllm function + mock_generate_with_vllm.return_value = [mock_completion_output] + + # Call the function + await batch_inference("this config data gets ignored because we mock model_validate_json") + + # Assertions + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + ], + any_order=True, + ) + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("subprocess.Popen") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") +async def test_batch_inference_failed_to_download_model_but_proceed( + mock_builtins_open_func, + mock_open_func, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_engine_request, + mock_vllm, + create_batch_completions_engine_request, + create_batch_completions_request_content, + mock_s3_client, + mock_process, + mock_completion_output, +): + # Mock the necessary objects and data + mock_process.returncode = 1 # Failed to download model + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( + create_batch_completions_engine_request + ) + mock_create_batch_completions_request_content.model_validate_json.return_value = ( + create_batch_completions_request_content + ) + + # Mock the generate_with_vllm function + mock_generate_with_vllm.return_value = [mock_completion_output] + + # Call the function + await batch_inference("this config data gets ignored because we mock model_validate_json") + + # Assertions + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + ], + any_order=True, + ) + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("subprocess.Popen") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") +@patch("model_engine_server.inference.batch_inference.vllm_batch.os.getenv") +async def test_batch_inference_two_workers( + mock_getenv, + mock_builtins_open_func, + mock_open_func, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_engine_request, + mock_vllm, + create_batch_completions_engine_request, + create_batch_completions_request_content, + mock_s3_client, + mock_process, + mock_completion_output, +): + # Mock the necessary objects and data + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + create_batch_completions_engine_request.data_parallelism = 2 + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( + create_batch_completions_engine_request + ) + mock_create_batch_completions_request_content.model_validate_json.return_value = ( + create_batch_completions_request_content + ) + + # Mock the generate_with_vllm function + mock_generate_with_vllm.return_value = [mock_completion_output] + + indexes = [1, 0] + + def side_effect(key, default): + if key == "JOB_COMPLETION_INDEX": + return indexes.pop(0) + return default + + mock_getenv.side_effect = side_effect + # Batch completion worker 1 + await batch_inference("this config data gets ignored because we mock model_validate_json") + + # Assertions + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path.1", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + ], + any_order=True, + ) + + # Batch completion worker 0 + await batch_inference("this config data gets ignored because we mock model_validate_json") + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path.1", "r"), + call("output_data_path.0", "w"), + call("output_data_path.0", "r"), + call("output_data_path", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + call().write("["), + call().write(","), + call().write("]"), + ], + any_order=True, + ) + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("subprocess.Popen") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") +@patch("model_engine_server.inference.batch_inference.vllm_batch.os.getenv") +async def test_batch_inference_delete_chunks( + mock_getenv, + mock_builtins_open_func, + mock_open_func, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_engine_request, + mock_vllm, + create_batch_completions_engine_request, + create_batch_completions_request_content, + mock_s3_client, + mock_process, + mock_completion_output, +): + # Mock the necessary objects and data + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + create_batch_completions_engine_request.data_parallelism = 2 + create_batch_completions_engine_request.output_data_path = "s3://bucket/key" + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( + create_batch_completions_engine_request + ) + mock_create_batch_completions_request_content.model_validate_json.return_value = ( + create_batch_completions_request_content + ) + + # Mock the generate_with_vllm function + mock_generate_with_vllm.return_value = [mock_completion_output] + + indexes = [1, 0] + + def side_effect(key, default): + if key == "JOB_COMPLETION_INDEX": + return indexes.pop(0) + return default + + mock_getenv.side_effect = side_effect + # Batch completion worker 1 + await batch_inference("this config data gets ignored because we mock model_validate_json") + + # Assertions + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("s3://bucket/key.1", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + ], + any_order=True, + ) + + # Batch completion worker 0 + await batch_inference("this config data gets ignored because we mock model_validate_json") + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("s3://bucket/key.1", "r"), + call("s3://bucket/key.0", "w"), + call("s3://bucket/key.0", "r"), + call("s3://bucket/key", "w"), + call().write(json.dumps([mock_completion_output.dict()])), + call().write("["), + call().write(","), + call().write("]"), + ], + any_order=True, + ) + + mock_s3_client.delete_object.assert_has_calls( + [call(Bucket="bucket", Key="key.0"), call(Bucket="bucket", Key="key.1")] + ) + + +def test_file_exists(): + mock_open_func = mock_open() + path = "test_path" + + with patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + mock_open_func, + ): + result = file_exists(path) + + mock_open_func.assert_called_once_with(path, "r") + assert result is True + + +def test_file_exists_no_such_key(): + path = "test_path" + + with patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + side_effect=IOError("No such key"), + ): + result = file_exists(path) + + assert result is False + + +@pytest.mark.asyncio +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_vllm_engine") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsEngineRequest" +) +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.CreateBatchCompletionsRequestContent" +) +@patch("model_engine_server.inference.batch_inference.vllm_batch.generate_with_vllm") +@patch("model_engine_server.inference.batch_inference.vllm_batch.get_s3_client") +@patch("model_engine_server.inference.batch_inference.vllm_batch.subprocess.Popen") +@patch("subprocess.run") +@patch( + "model_engine_server.inference.batch_inference.vllm_batch.smart_open.open", + new_callable=mock_open, + read_data="Mocked content", +) +@patch("builtins.open", new_callable=mock_open, read_data="Mocked content") +async def test_batch_inference_tool_completion( + mock_builtins_open_func, + mock_open_func, + mock_run, + mock_popen, + mock_get_s3_client, + mock_generate_with_vllm, + mock_create_batch_completions_request_content, + mock_create_batch_completions_engine_request, + mock_vllm, + create_batch_completions_tool_completion_request, + create_batch_completions_tool_completion_request_content, + mock_s3_client, + mock_process, + mock_tool_completion_output, + mock_tool_completion_output2, + mock_run_output, +): + # Mock the necessary objects and data + mock_run.return_value = mock_run_output + mock_popen.return_value = mock_process + mock_get_s3_client.return_value = mock_s3_client + mock_create_batch_completions_engine_request.model_validate_json.return_value = ( + create_batch_completions_tool_completion_request + ) + mock_create_batch_completions_request_content.model_validate_json.return_value = ( + create_batch_completions_tool_completion_request_content + ) + + # Mock the generate_with_vllm function + mock_generate_with_vllm.side_effect = [ + [mock_tool_completion_output], + [mock_tool_completion_output2], + ] + + # Call the function + await batch_inference("this config data gets ignored because we mock model_validate_json") + + # Assertions + mock_create_batch_completions_engine_request.model_validate_json.assert_called_once() + mock_open_func.assert_has_calls( + [ + call("input_data_path", "r"), + call("output_data_path", "w"), + call().write( + json.dumps( + [ + { + "text": "```python\nimport math\nprint(math.sqrt(2))\n```\n1.4142135623730951\n>>>\nFinal Answer: 4\n", + "num_prompt_tokens": 10, + "num_completion_tokens": 49, + "tokens": [ + {"token": "``", "log_prob": -0.1980377733707428}, + {"token": "`", "log_prob": -0.0037908137310296297}, + {"token": "python", "log_prob": -0.015637163072824478}, + {"token": "\n", "log_prob": -0.0010788579238578677}, + {"token": "import", "log_prob": -0.04351021721959114}, + {"token": " math", "log_prob": -0.0021214615553617477}, + {"token": "\n", "log_prob": -0.002169043058529496}, + {"token": "print", "log_prob": -0.06555093079805374}, + {"token": "(", "log_prob": -0.005272886715829372}, + {"token": "math", "log_prob": -0.009995171800255775}, + {"token": ".", "log_prob": -0.0002040654799202457}, + {"token": "sqrt", "log_prob": -0.00886327400803566}, + {"token": "(", "log_prob": -0.0015410225605592132}, + {"token": "2", "log_prob": -0.008573509752750397}, + {"token": "))", "log_prob": -0.010970987379550934}, + {"token": "\n", "log_prob": -0.002175347413867712}, + {"token": "``", "log_prob": -0.01911235973238945}, + {"token": "`", "log_prob": -0.0005327236140146852}, + {"token": "\n", "log_prob": -0.002304519060999155}, + {"token": "1", "log_prob": -0.10852570831775665}, + {"token": ".", "log_prob": -0.007146273739635944}, + {"token": "4", "log_prob": -0.003810290014371276}, + {"token": "1", "log_prob": -0.002774677239358425}, + {"token": "4", "log_prob": -0.16946221888065338}, + {"token": ".", "log_prob": -0.007678280584514141}, + {"token": ".", "log_prob": -0.021146666258573532}, + {"token": ".", "log_prob": -0.3870151937007904}, + {"token": "\n", "log_prob": -0.027081478387117386}, + {"token": "Final", "log_prob": -0.1980377733707428}, + { + "token": " Answer", + "log_prob": -0.0037908137310296297, + }, + {"token": ":", "log_prob": -0.015637163072824478}, + {"token": " ", "log_prob": -0.0010788579238578677}, + {"token": "4", "log_prob": -0.04351021721959114}, + {"token": "\n", "log_prob": -0.0021214615553617477}, + ], + } + ] + ) + ), + ], + any_order=True, + ) diff --git a/server/tests/unit/infra/gateways/conftest.py b/model-engine/tests/unit/infra/gateways/conftest.py similarity index 51% rename from server/tests/unit/infra/gateways/conftest.py rename to model-engine/tests/unit/infra/gateways/conftest.py index fe82601e..4ca93044 100644 --- a/server/tests/unit/infra/gateways/conftest.py +++ b/model-engine/tests/unit/infra/gateways/conftest.py @@ -1,6 +1,6 @@ import pytest -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest @pytest.fixture @@ -15,33 +15,22 @@ def create_resources_request_sync_pytorch( @pytest.fixture -def create_resources_request_async_tensorflow( - test_api_key: str, build_endpoint_request_async_tensorflow: BuildEndpointRequest +def create_resources_request_async_runnable_image( + test_api_key: str, build_endpoint_request_async_runnable_image: BuildEndpointRequest ) -> CreateOrUpdateResourcesRequest: create_resources_request = CreateOrUpdateResourcesRequest( - build_endpoint_request=build_endpoint_request_async_tensorflow, + build_endpoint_request=build_endpoint_request_async_runnable_image, image="test_image", ) return create_resources_request @pytest.fixture -def create_resources_request_async_custom( - test_api_key: str, build_endpoint_request_async_custom: BuildEndpointRequest +def create_resources_request_sync_runnable_image( + test_api_key: str, build_endpoint_request_sync_runnable_image: BuildEndpointRequest ) -> CreateOrUpdateResourcesRequest: create_resources_request = CreateOrUpdateResourcesRequest( - build_endpoint_request=build_endpoint_request_async_custom, - image="test_image", - ) - return create_resources_request - - -@pytest.fixture -def create_resources_request_sync_custom( - test_api_key: str, build_endpoint_request_sync_custom: BuildEndpointRequest -) -> CreateOrUpdateResourcesRequest: - create_resources_request = CreateOrUpdateResourcesRequest( - build_endpoint_request=build_endpoint_request_sync_custom, + build_endpoint_request=build_endpoint_request_sync_runnable_image, image="test_image", ) return create_resources_request @@ -49,8 +38,7 @@ def create_resources_request_sync_custom( @pytest.fixture def create_resources_request_streaming_runnable_image( - test_api_key: str, - build_endpoint_request_streaming_runnable_image: BuildEndpointRequest, + test_api_key: str, build_endpoint_request_streaming_runnable_image: BuildEndpointRequest ) -> CreateOrUpdateResourcesRequest: create_resources_request = CreateOrUpdateResourcesRequest( build_endpoint_request=build_endpoint_request_streaming_runnable_image, diff --git a/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py b/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py new file mode 100644 index 00000000..257b0db5 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/k8s_fake_objects.py @@ -0,0 +1,74 @@ +# Various fake k8s objects to be used in mocking out the python k8s api client +# Only classes are defined here. If you need to add various fields to the classes, please do so here. + +from dataclasses import dataclass, field +from datetime import datetime +from typing import List, Optional + + +@dataclass +class FakeK8sV1ObjectMeta: + name: str = "fake_name" + namespace: str = "fake_namespace" + annotations: dict = field(default_factory=dict) + labels: dict = field(default_factory=dict) + creation_timestamp: datetime = datetime(2021, 1, 1, 0, 0, 0, 0) + # TODO: everything else + + +@dataclass +class FakeK8sV1PodStatus: + phase: str = "Running" + # TODO: everything else + + +@dataclass +class FakeK8sV1JobStatus: + active: int = 0 + succeeded: int = 0 + failed: int = 0 + ready: int = 0 + terminating: int = 0 + completion_time: Optional[datetime] = None + + +@dataclass +class FakeK8sV1JobSpec: + completions: int = 1 + parallelism: int = 1 + + +@dataclass +class FakeK8sV1Job: + metadata: FakeK8sV1ObjectMeta = FakeK8sV1ObjectMeta() + status: FakeK8sV1JobStatus = FakeK8sV1JobStatus() + spec: FakeK8sV1JobSpec = FakeK8sV1JobSpec() + # TODO: spec, api_version, kind + + +@dataclass +class FakeK8sV1JobList: + items: List[FakeK8sV1Job] = field(default_factory=list) + + +@dataclass +class FakeK8sV1Pod: + metadata: FakeK8sV1ObjectMeta = FakeK8sV1ObjectMeta() + status: FakeK8sV1PodStatus = FakeK8sV1PodStatus() + # TODO: spec, api_version, kind + + +@dataclass +class FakeK8sV1PodList: + items: List[FakeK8sV1Pod] = field(default_factory=list) + + +@dataclass +class FakeK8sEnvVar: + name: str + value: str + + +@dataclass +class FakeK8sDeploymentContainer: + env: List[FakeK8sEnvVar] diff --git a/model-engine/tests/unit/infra/gateways/resources/example_lws_config.json b/model-engine/tests/unit/infra/gateways/resources/example_lws_config.json new file mode 100644 index 00000000..41793478 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/resources/example_lws_config.json @@ -0,0 +1,829 @@ +{ + "apiVersion": "leaderworkerset.x-k8s.io/v1", + "kind": "LeaderWorkerSet", + "metadata": { + "creationTimestamp": "2000-01-01T01:38:14Z", + "generation": 1, + "labels": { + "created_by": "userid000000", + "endpoint_id": "end_abcdefg", + "endpoint_name": "endpoint_name", + "env": "training", + "managed-by": "model-engine", + "owner": "userid000000", + "product": "testing", + "tags.datadoghq.com/env": "training", + "tags.datadoghq.com/service": "endpoint_name", + "tags.datadoghq.com/version": "05fb96620692a205e52d33980eff475d6a52748a", + "team": "infra", + "use_scale_launch_endpoint_network_policy": "true", + "user_id": "userid000000" + }, + "managedFields": [ + { + "apiVersion": "leaderworkerset.x-k8s.io/v1", + "fieldsType": "FieldsV1", + "fieldsV1": { + "f:metadata": { + "f:labels": { + ".": {}, + "f:created_by": {}, + "f:endpoint_id": {}, + "f:endpoint_name": {}, + "f:env": {}, + "f:managed-by": {}, + "f:owner": {}, + "f:product": {}, + "f:tags.datadoghq.com/env": {}, + "f:tags.datadoghq.com/service": {}, + "f:tags.datadoghq.com/version": {}, + "f:team": {}, + "f:use_scale_launch_endpoint_network_policy": {}, + "f:user_id": {} + } + }, + "f:spec": { + ".": {}, + "f:leaderWorkerTemplate": { + ".": {}, + "f:leaderTemplate": { + ".": {}, + "f:metadata": { + ".": {}, + "f:annotations": { + ".": {}, + "f:ad.datadoghq.com/main.logs": {}, + "f:kubernetes.io/change-cause": {} + }, + "f:labels": { + ".": {}, + "f:app": {}, + "f:created_by": {}, + "f:endpoint_id": {}, + "f:endpoint_name": {}, + "f:env": {}, + "f:managed-by": {}, + "f:owner": {}, + "f:product": {}, + "f:sidecar.istio.io/inject": {}, + "f:tags.datadoghq.com/env": {}, + "f:tags.datadoghq.com/service": {}, + "f:tags.datadoghq.com/version": {}, + "f:team": {}, + "f:use_scale_launch_endpoint_network_policy": {}, + "f:user_id": {}, + "f:version": {} + } + }, + "f:spec": { + ".": {}, + "f:affinity": { + ".": {}, + "f:podAffinity": { + ".": {}, + "f:preferredDuringSchedulingIgnoredDuringExecution": {} + } + }, + "f:containers": {}, + "f:nodeSelector": {}, + "f:priorityClassName": {}, + "f:serviceAccount": {}, + "f:terminationGracePeriodSeconds": {}, + "f:tolerations": {}, + "f:volumes": {} + } + }, + "f:restartPolicy": {}, + "f:size": {}, + "f:workerTemplate": { + ".": {}, + "f:metadata": { + ".": {}, + "f:annotations": { + ".": {}, + "f:ad.datadoghq.com/main.logs": {}, + "f:kubernetes.io/change-cause": {} + }, + "f:labels": { + ".": {}, + "f:app": {}, + "f:created_by": {}, + "f:endpoint_id": {}, + "f:endpoint_name": {}, + "f:env": {}, + "f:managed-by": {}, + "f:owner": {}, + "f:product": {}, + "f:sidecar.istio.io/inject": {}, + "f:tags.datadoghq.com/env": {}, + "f:tags.datadoghq.com/service": {}, + "f:tags.datadoghq.com/version": {}, + "f:team": {}, + "f:use_scale_launch_endpoint_network_policy": {}, + "f:user_id": {}, + "f:version": {} + } + }, + "f:spec": { + ".": {}, + "f:affinity": { + ".": {}, + "f:podAffinity": { + ".": {}, + "f:preferredDuringSchedulingIgnoredDuringExecution": {} + } + }, + "f:containers": {}, + "f:nodeSelector": {}, + "f:priorityClassName": {}, + "f:serviceAccount": {}, + "f:terminationGracePeriodSeconds": {}, + "f:tolerations": {}, + "f:volumes": {} + } + } + }, + "f:replicas": {}, + "f:startupPolicy": {} + } + }, + "manager": "OpenAPI-Generator", + "operation": "Update", + "time": "2000-01-01T01:38:14Z" + }, + { + "apiVersion": "leaderworkerset.x-k8s.io/v1", + "fieldsType": "FieldsV1", + "fieldsV1": { + "f:status": { + ".": {}, + "f:conditions": {}, + "f:hpaPodSelector": {} + } + }, + "manager": "manager", + "operation": "Update", + "subresource": "status", + "time": "2000-01-01T01:38:14Z" + } + ], + "name": "launch-endpoint-id-end-abcdefg", + "namespace": "scale-deploy", + "resourceVersion": "2289583184", + "uid": "1d66ad78-3148-41b3-83fd-fb71d7656fb1" + }, + "spec": { + "leaderWorkerTemplate": { + "leaderTemplate": { + "metadata": { + "annotations": { + "ad.datadoghq.com/main.logs": "[{\"service\": \"endpoint_name\", \"source\": \"python\"}]", + "kubernetes.io/change-cause": "Deployment at 2000-01-01 01:38:13.814158 UTC. Using deployment constructed from model bundle ID bun_cqi4v12d6mt002nap720, model bundle name endpoint_name, endpoint ID end_abcdefg" + }, + "labels": { + "app": "launch-endpoint-id-end-abcdefg", + "created_by": "userid000000", + "endpoint_id": "end_abcdefg", + "endpoint_name": "endpoint_name", + "env": "training", + "managed-by": "model-engine", + "owner": "userid000000", + "product": "testing", + "sidecar.istio.io/inject": "false", + "tags.datadoghq.com/env": "training", + "tags.datadoghq.com/service": "endpoint_name", + "tags.datadoghq.com/version": "05fb96620692a205e52d33980eff475d6a52748a", + "team": "infra", + "use_scale_launch_endpoint_network_policy": "true", + "user_id": "userid000000", + "version": "v1" + } + }, + "spec": { + "affinity": { + "podAffinity": { + "preferredDuringSchedulingIgnoredDuringExecution": [ + { + "podAffinityTerm": { + "labelSelector": { + "matchExpressions": [ + { + "key": "app", + "operator": "In", + "values": [ + "launch-endpoint-id-end-abcdefg" + ] + } + ] + }, + "topologyKey": "kubernetes.io/hostname" + }, + "weight": 1 + }, + { + "podAffinityTerm": { + "labelSelector": { + "matchExpressions": [ + { + "key": "3d45a96760a60018eb4a9d874e919aef", + "operator": "In", + "values": [ + "True" + ] + } + ] + }, + "topologyKey": "kubernetes.io/hostname" + }, + "weight": 100 + } + ] + } + }, + "containers": [ + { + "command": [ + "/usr/bin/dumb-init", + "--", + "ddtrace-run", + "python", + "-m", + "model_engine_server.inference.forwarding.http_forwarder", + "--config", + "/workspace/model-engine/model_engine_server/inference/configs/service--http_forwarder.yaml", + "--port", + "5000", + "--num-workers", + "2", + "--set", + "forwarder.sync.predict_route=/predict", + "--set", + "forwarder.stream.predict_route=/stream", + "--set", + "forwarder.sync.healthcheck_route=/health", + "--set", + "forwarder.stream.healthcheck_route=/health" + ], + "env": [ + { + "name": "DD_TRACE_ENABLED", + "value": "True" + }, + { + "name": "DD_REMOTE_CONFIGURATION_ENABLED", + "value": "false" + }, + { + "name": "DD_SERVICE", + "value": "endpoint_name" + }, + { + "name": "DD_ENV", + "value": "training" + }, + { + "name": "DD_VERSION", + "value": "05fb96620692a205e52d33980eff475d6a52748a" + }, + { + "name": "DD_AGENT_HOST", + "valueFrom": { + "fieldRef": { + "fieldPath": "status.hostIP" + } + } + }, + { + "name": "AWS_PROFILE", + "value": "aws-profile" + }, + { + "name": "AWS_CONFIG_FILE", + "value": "/opt/.aws/config" + }, + { + "name": "RESULTS_S3_BUCKET", + "value": "bucket" + }, + { + "name": "BASE_PATH", + "value": "/workspace" + }, + { + "name": "ML_INFRA_SERVICES_CONFIG_PATH", + "value": "/workspace/model-engine-internal/resources/configs/infra_config_training.yaml" + } + ], + "image": "000000000000.dkr.ecr.us-west-2.amazonaws.com/llm-engine:tag", + "imagePullPolicy": "IfNotPresent", + "name": "http-forwarder", + "ports": [ + { + "containerPort": 5000, + "name": "http", + "protocol": "TCP" + } + ], + "readinessProbe": { + "httpGet": { + "path": "/readyz", + "port": 5000 + }, + "initialDelaySeconds": 10, + "periodSeconds": 5, + "timeoutSeconds": 5 + }, + "resources": { + "limits": { + "cpu": "1", + "ephemeral-storage": "1G", + "memory": "2Gi" + }, + "requests": { + "cpu": "1", + "ephemeral-storage": "100M", + "memory": "100M" + } + }, + "volumeMounts": [ + { + "mountPath": "/opt/.aws/config", + "name": "config-volume", + "subPath": "config" + }, + { + "mountPath": "/workspace/user_config", + "name": "user-config", + "subPath": "raw_data" + }, + { + "mountPath": "/workspace/endpoint_config", + "name": "endpoint-config", + "subPath": "raw_data" + } + ] + }, + { + "command": [ + "/bin/bash", + "-c", + "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' s3://bucket/tag/userid000000/model_weights/* model_files;/workspace/init_ray.sh leader --ray_cluster_size=$RAY_CLUSTER_SIZE --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local;python -m vllm_server --model model_files --tensor-parallel-size 1 --port 5005 --disable-log-requests--enforce-eager" + ], + "env": [ + { + "name": "VLLM_HOST_IP", + "value": "$(K8S_LEADER_NAME).$(K8S_LWS_NAME).$(K8S_OWN_NAMESPACE).svc.cluster.local" + }, + { + "name": "NCCL_SOCKET_IFNAME", + "value": "eth0" + }, + { + "name": "GLOO_SOCKET_IFNAME", + "value": "eth0" + }, + { + "name": "NCCL_DEBUG", + "value": "INFO" + }, + { + "name": "VLLM_LOGGING_LEVEL", + "value": "INFO" + }, + { + "name": "AWS_PROFILE", + "value": "aws-profile" + }, + { + "name": "AWS_CONFIG_FILE", + "value": "/opt/.aws/config" + }, + { + "name": "K8S_OWN_POD_NAME", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.name" + } + } + }, + { + "name": "K8S_OWN_NAMESPACE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.namespace" + } + } + }, + { + "name": "K8S_LWS_NAME", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.labels['leaderworkerset.sigs.k8s.io/name']" + } + } + }, + { + "name": "K8S_LWS_CLUSTER_SIZE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.annotations['leaderworkerset.sigs.k8s.io/size']" + } + } + }, + { + "name": "DD_TRACE_ENABLED", + "value": "true" + }, + { + "name": "DD_SERVICE", + "value": "endpoint_name" + }, + { + "name": "DD_ENV", + "value": "training" + }, + { + "name": "DD_VERSION", + "value": "05fb96620692a205e52d33980eff475d6a52748a" + }, + { + "name": "DD_AGENT_HOST", + "valueFrom": { + "fieldRef": { + "fieldPath": "status.hostIP" + } + } + } + ], + "image": "000000000000.dkr.ecr.us-west-2.amazonaws.com/vllm:0.5.3.post1", + "imagePullPolicy": "IfNotPresent", + "name": "lws-leader", + "ports": [ + { + "containerPort": 5005, + "name": "http", + "protocol": "TCP" + } + ], + "readinessProbe": { + "httpGet": { + "path": "/health", + "port": 5005 + }, + "initialDelaySeconds": 10, + "periodSeconds": 5, + "timeoutSeconds": 5 + }, + "resources": { + "limits": { + "cpu": "10", + "ephemeral-storage": "94Gi", + "memory": "40Gi", + "nvidia.com/gpu": "1" + }, + "requests": { + "cpu": "10", + "ephemeral-storage": "94Gi", + "memory": "40Gi", + "nvidia.com/gpu": "1" + } + }, + "volumeMounts": [ + { + "mountPath": "/opt/.aws/config", + "name": "config-volume", + "subPath": "config" + }, + { + "mountPath": "/dev/shm", + "name": "dshm" + }, + { + "mountPath": "/app/user_config", + "name": "user-config", + "subPath": "raw_data" + }, + { + "mountPath": "/app/endpoint_config", + "name": "endpoint-config", + "subPath": "raw_data" + } + ] + } + ], + "nodeSelector": { + "k8s.amazonaws.com/accelerator": "nvidia-hopper-h100", + "node-lifecycle": "normal" + }, + "priorityClassName": "model-engine-high-priority", + "serviceAccount": "aws-profile", + "terminationGracePeriodSeconds": 600, + "tolerations": [ + { + "effect": "NoSchedule", + "key": "nvidia.com/gpu", + "operator": "Exists" + } + ], + "volumes": [ + { + "configMap": { + "name": "aws-profile-config" + }, + "name": "config-volume" + }, + { + "configMap": { + "name": "launch-endpoint-id-end-abcdefg" + }, + "name": "user-config" + }, + { + "configMap": { + "name": "launch-endpoint-id-end-abcdefg-endpoint-config" + }, + "name": "endpoint-config" + }, + { + "emptyDir": { + "medium": "Memory" + }, + "name": "dshm" + } + ] + } + }, + "restartPolicy": "RecreateGroupOnPodRestart", + "size": 2, + "workerTemplate": { + "metadata": { + "annotations": { + "ad.datadoghq.com/main.logs": "[{\"service\": \"endpoint_name\", \"source\": \"python\"}]", + "kubernetes.io/change-cause": "Deployment at 2000-01-01 01:38:13.814158 UTC. Using deployment constructed from model bundle ID bun_cqi4v12d6mt002nap720, model bundle name endpoint_name, endpoint ID end_abcdefg" + }, + "labels": { + "app": "launch-endpoint-id-end-abcdefg", + "created_by": "userid000000", + "endpoint_id": "end_abcdefg", + "endpoint_name": "endpoint_name", + "env": "training", + "managed-by": "model-engine", + "owner": "userid000000", + "product": "testing", + "sidecar.istio.io/inject": "false", + "tags.datadoghq.com/env": "training", + "tags.datadoghq.com/service": "endpoint_name", + "tags.datadoghq.com/version": "05fb96620692a205e52d33980eff475d6a52748a", + "team": "infra", + "use_scale_launch_endpoint_network_policy": "true", + "user_id": "userid000000", + "version": "v1" + } + }, + "spec": { + "affinity": { + "podAffinity": { + "preferredDuringSchedulingIgnoredDuringExecution": [ + { + "podAffinityTerm": { + "labelSelector": { + "matchExpressions": [ + { + "key": "app", + "operator": "In", + "values": [ + "launch-endpoint-id-end-abcdefg" + ] + } + ] + }, + "topologyKey": "kubernetes.io/hostname" + }, + "weight": 1 + }, + { + "podAffinityTerm": { + "labelSelector": { + "matchExpressions": [ + { + "key": "3d45a96760a60018eb4a9d874e919aef", + "operator": "In", + "values": [ + "True" + ] + } + ] + }, + "topologyKey": "kubernetes.io/hostname" + }, + "weight": 100 + } + ] + } + }, + "containers": [ + { + "command": [ + "/bin/bash", + "-c", + "./s5cmd --numworkers 512 cp --concurrency 10 --include '*.model' --include '*.json' --include '*.safetensors' --exclude 'optimizer*' s3://bucket/key/userid000000/model_weights/* model_files;/workspace/init_ray.sh worker --ray_cluster_size=$RAY_CLUSTER_SIZE --ray_address=$K8S_LEADER_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local --own_address=$K8S_OWN_POD_NAME.$K8S_LWS_NAME.$K8S_OWN_NAMESPACE.svc.cluster.local" + ], + "env": [ + { + "name": "VLLM_HOST_IP", + "value": "$(K8S_LEADER_NAME).$(K8S_LWS_NAME).$(K8S_OWN_NAMESPACE).svc.cluster.local" + }, + { + "name": "NCCL_SOCKET_IFNAME", + "value": "eth0" + }, + { + "name": "GLOO_SOCKET_IFNAME", + "value": "eth0" + }, + { + "name": "NCCL_DEBUG", + "value": "INFO" + }, + { + "name": "VLLM_LOGGING_LEVEL", + "value": "INFO" + }, + { + "name": "AWS_PROFILE", + "value": "aws-profile" + }, + { + "name": "AWS_CONFIG_FILE", + "value": "/opt/.aws/config" + }, + { + "name": "K8S_OWN_POD_NAME", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.name" + } + } + }, + { + "name": "K8S_OWN_NAMESPACE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.namespace" + } + } + }, + { + "name": "K8S_LWS_NAME", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.labels['leaderworkerset.sigs.k8s.io/name']" + } + } + }, + { + "name": "K8S_LWS_CLUSTER_SIZE", + "valueFrom": { + "fieldRef": { + "fieldPath": "metadata.annotations['leaderworkerset.sigs.k8s.io/size']" + } + } + }, + { + "name": "DD_TRACE_ENABLED", + "value": "true" + }, + { + "name": "DD_SERVICE", + "value": "endpoint_name" + }, + { + "name": "DD_ENV", + "value": "training" + }, + { + "name": "DD_VERSION", + "value": "05fb96620692a205e52d33980eff475d6a52748a" + }, + { + "name": "DD_AGENT_HOST", + "valueFrom": { + "fieldRef": { + "fieldPath": "status.hostIP" + } + } + } + ], + "image": "000000000000.dkr.ecr.us-west-2.amazonaws.com/vllm:0.5.3.post1", + "imagePullPolicy": "IfNotPresent", + "name": "lws-worker", + "ports": [ + { + "containerPort": 5005, + "name": "http", + "protocol": "TCP" + } + ], + "resources": { + "limits": { + "cpu": "10", + "ephemeral-storage": "94Gi", + "memory": "40Gi", + "nvidia.com/gpu": "1" + }, + "requests": { + "cpu": "10", + "ephemeral-storage": "94Gi", + "memory": "40Gi", + "nvidia.com/gpu": "1" + } + }, + "volumeMounts": [ + { + "mountPath": "/opt/.aws/config", + "name": "config-volume", + "subPath": "config" + }, + { + "mountPath": "/dev/shm", + "name": "dshm" + }, + { + "mountPath": "/app/user_config", + "name": "user-config", + "subPath": "raw_data" + }, + { + "mountPath": "/app/endpoint_config", + "name": "endpoint-config", + "subPath": "raw_data" + } + ] + } + ], + "nodeSelector": { + "k8s.amazonaws.com/accelerator": "nvidia-hopper-h100", + "node-lifecycle": "normal" + }, + "priorityClassName": "model-engine-high-priority", + "serviceAccount": "aws-profile", + "terminationGracePeriodSeconds": 600, + "tolerations": [ + { + "effect": "NoSchedule", + "key": "nvidia.com/gpu", + "operator": "Exists" + } + ], + "volumes": [ + { + "configMap": { + "name": "aws-profile-config" + }, + "name": "config-volume" + }, + { + "configMap": { + "name": "launch-endpoint-id-end-abcdefg" + }, + "name": "user-config" + }, + { + "configMap": { + "name": "launch-endpoint-id-end-abcdefg-endpoint-config" + }, + "name": "endpoint-config" + }, + { + "emptyDir": { + "medium": "Memory" + }, + "name": "dshm" + } + ] + } + } + }, + "replicas": 0, + "rolloutStrategy": { + "rollingUpdateConfiguration": { + "maxSurge": 0, + "maxUnavailable": 1 + }, + "type": "RollingUpdate" + }, + "startupPolicy": "LeaderCreated" + }, + "status": { + "conditions": [ + { + "lastTransitionTime": "2000-01-01T01:38:14Z", + "message": "All replicas are ready", + "reason": "AllGroupsReady", + "status": "True", + "type": "Available" + } + ], + "hpaPodSelector": "leaderworkerset.sigs.k8s.io/name=launch-endpoint-id-end-abcdefg,leaderworkerset.sigs.k8s.io/worker-index=0" + } + } \ No newline at end of file diff --git a/model-engine/tests/unit/infra/gateways/resources/test_image_cache_gateway.py b/model-engine/tests/unit/infra/gateways/resources/test_image_cache_gateway.py new file mode 100644 index 00000000..60fa39cb --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/resources/test_image_cache_gateway.py @@ -0,0 +1,66 @@ +from typing import Dict, Set +from unittest.mock import AsyncMock, patch + +import pytest +from model_engine_server.infra.gateways.resources.image_cache_gateway import ( + CachedImages, + ImageCacheGateway, +) + +MODULE_PATH = "model_engine_server.infra.gateways.resources.image_cache_gateway" + + +@pytest.fixture +def mock_apps_client(): + mock_client = AsyncMock() + with patch( + f"{MODULE_PATH}.get_kubernetes_apps_client", + return_value=mock_client, + ): + yield mock_client + + +@pytest.mark.asyncio +async def test_create_or_update_image_cache( + mock_apps_client, +): + gateway = ImageCacheGateway() + await gateway.create_or_update_image_cache( + CachedImages( + cpu=["cpu_image"], + a10=["a10_image"], + a100=["a100_image"], + t4=["t4_image"], + h100=["h100_image"], + ) + ) + + # Needs to correspond with model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml + expected_images: Dict[str, Set[str]] = { + "cpu": {"cpu_image"}, + "a10": {"a10_image"}, + "a100": {"a100_image"}, + "t4": {"t4_image"}, + "h100": {"h100_image"}, + } + + actual_images: Dict[str, Set[str]] = { + "cpu": set(), + "a10": set(), + "a100": set(), + "t4": set(), + "h100": set(), + } + + for call_args in mock_apps_client.create_namespaced_daemon_set.call_args_list: + _, kwargs = call_args + compute_type = kwargs["body"]["metadata"]["name"].split("-")[-1] + actual_images[compute_type] = set( + container["image"] + for container in kwargs["body"]["spec"]["template"]["spec"]["containers"] + ) + + for k in expected_images.keys(): + assert expected_images[k].issubset( + actual_images[k] + ), f"Missing {expected_images[k].difference(actual_images[k])}" diff --git a/server/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py similarity index 63% rename from server/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py rename to model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py index 6dbefbe1..84a5063f 100644 --- a/server/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_k8s_endpoint_resource_delegate.py @@ -1,42 +1,50 @@ -from dataclasses import dataclass +import json +import os from typing import Any, Dict, List from unittest.mock import AsyncMock, Mock, patch import pytest from kubernetes_asyncio.client.rest import ApiException -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest -from llm_engine_server.domain.entities import ( +from model_engine_server.common.config import hmi_config +from model_engine_server.common.dtos.resource_manager import CreateOrUpdateResourcesRequest +from model_engine_server.common.env_vars import GIT_TAG +from model_engine_server.domain.entities import ( + ModelBundle, ModelEndpointConfig, ModelEndpointType, ModelEndpointUserConfigState, ) -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( DATADOG_ENV_VAR, K8SEndpointResourceDelegate, - add_datadog_env_to_main_container, + add_datadog_env_to_container, get_main_container_from_deployment_template, + k8s_yaml_exists, load_k8s_yaml, ) -from llm_engine_server.infra.gateways.resources.k8s_resource_types import ( +from model_engine_server.infra.gateways.resources.k8s_resource_types import ( DictStrInt, DictStrStr, ResourceArguments, ) +from tests.unit.infra.gateways.k8s_fake_objects import FakeK8sDeploymentContainer, FakeK8sEnvVar -MODULE_PATH = "llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate" +MODULE_PATH = "model_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate" +EXAMPLE_LWS_CONFIG_PATH = os.path.abspath(os.path.join(__file__, "..", "example_lws_config.json")) +with open(EXAMPLE_LWS_CONFIG_PATH, "r") as f: + EXAMPLE_LWS_CONFIG = json.load(f) -@dataclass -class FakeK8sEnvVar: - name: str - value: str - -@dataclass -class FakeK8sDeploymentContainer: - env: List[FakeK8sEnvVar] +@pytest.fixture +def mock_get_kubernetes_cluster_version(): + mock_version = "1.26" + with patch( + f"{MODULE_PATH}.get_kubernetes_cluster_version", + return_value=mock_version, + ): + yield mock_version @pytest.fixture @@ -69,6 +77,16 @@ def mock_autoscaling_client(): yield mock_client +@pytest.fixture +def mock_policy_client(): + mock_client = AsyncMock() + with patch( + f"{MODULE_PATH}.get_kubernetes_policy_client", + return_value=mock_client, + ): + yield mock_client + + @pytest.fixture def mock_custom_objects_client(): mock_client = AsyncMock() @@ -118,10 +136,17 @@ def k8s_endpoint_resource_delegate( return gateway +def test_k8s_yaml_exists(): + # This is tied to + # llm-engine/model-engine/model_engine_server/infra/gateways/resources/templates/service_template_config_map_circleci.yaml + assert k8s_yaml_exists("image-cache-h100.yaml"), "image-cache-h100.yaml should exist" + assert not k8s_yaml_exists( + "image-cache-abc9001.yaml" + ), "image-cache-abc9001.yaml should not exist" + + @pytest.mark.parametrize("resource_arguments_type", ResourceArguments.__args__) -def test_resource_arguments_type_and_add_datadog_env_to_main_container( - resource_arguments_type, -): +def test_resource_arguments_type_and_add_datadog_env_to_main_container(resource_arguments_type): # Convert the name of the type to a kebab case string # e.g. "BatchJobOrchestrationJobArguments" -> "batch-job-orchestration-job-arguments" resource_arguments_type_name = resource_arguments_type.__name__ @@ -157,7 +182,8 @@ def test_resource_arguments_type_and_add_datadog_env_to_main_container( deployment_template = load_k8s_yaml(f"{resource_arguments_type_name}.yaml", resource_arguments) if "runnable-image" in resource_arguments_type_name: - add_datadog_env_to_main_container(deployment_template) + user_container = get_main_container_from_deployment_template(deployment_template) + add_datadog_env_to_container(deployment_template, user_container) user_container = get_main_container_from_deployment_template(deployment_template) @@ -178,9 +204,8 @@ def _verify_deployment_labels( labels = build_endpoint_request.labels endpoint_name = model_endpoint_record.name env = "circleci" - git_tag = "54f8f73bfb1cce62a2b42326ccf9f49b5b145126" - k8s_resource_group_name = f"llm-engine-endpoint-id-{model_endpoint_record.id.replace('_', '-')}" + k8s_resource_group_name = f"launch-endpoint-id-{model_endpoint_record.id.replace('_', '-')}" assert body["metadata"]["name"] == k8s_resource_group_name assert body["metadata"]["namespace"] == hmi_config.endpoint_namespace @@ -191,15 +216,15 @@ def _verify_deployment_labels( "user_id": user_id, "endpoint_id": model_endpoint_record.id, "endpoint_name": endpoint_name, - "managed-by": "llm-engine", + "managed-by": "model-engine", "owner": user_id, "team": labels["team"], "product": labels["product"], "env": env, "tags.datadoghq.com/env": env, "tags.datadoghq.com/service": endpoint_name, - "tags.datadoghq.com/version": git_tag, - "use_scale_llm_engine_endpoint_network_policy": "true", + "tags.datadoghq.com/version": GIT_TAG, + "use_scale_launch_endpoint_network_policy": "true", } assert body["metadata"]["labels"] == expected_labels @@ -209,7 +234,7 @@ def _verify_deployment_labels( "user_id": user_id, "endpoint_id": model_endpoint_record.id, "endpoint_name": endpoint_name, - "managed-by": "llm-engine", + "managed-by": "model-engine", "owner": user_id, "team": labels["team"], "product": labels["product"], @@ -217,8 +242,8 @@ def _verify_deployment_labels( "version": "v1", "tags.datadoghq.com/env": env, "tags.datadoghq.com/service": endpoint_name, - "tags.datadoghq.com/version": git_tag, - "use_scale_llm_engine_endpoint_network_policy": "true", + "tags.datadoghq.com/version": GIT_TAG, + "use_scale_launch_endpoint_network_policy": "true", } if model_endpoint_record.endpoint_type == ModelEndpointType.ASYNC: @@ -237,9 +262,8 @@ def _verify_non_deployment_labels( labels = build_endpoint_request.labels endpoint_name = model_endpoint_record.name env = "circleci" - git_tag = "54f8f73bfb1cce62a2b42326ccf9f49b5b145126" - k8s_resource_group_name = f"llm-engine-endpoint-id-{model_endpoint_record.id.replace('_', '-')}" + k8s_resource_group_name = f"launch-endpoint-id-{model_endpoint_record.id.replace('_', '-')}" assert k8s_resource_group_name in body["metadata"]["name"] assert body["metadata"]["namespace"] == hmi_config.endpoint_namespace @@ -247,7 +271,7 @@ def _verify_non_deployment_labels( expected_labels = { "created_by": user_id, - "managed-by": "llm-engine", + "managed-by": "model-engine", "owner": user_id, "user_id": user_id, "endpoint_id": model_endpoint_record.id, @@ -257,8 +281,8 @@ def _verify_non_deployment_labels( "env": env, "tags.datadoghq.com/env": env, "tags.datadoghq.com/service": endpoint_name, - "tags.datadoghq.com/version": git_tag, - "use_scale_llm_engine_endpoint_network_policy": "true", + "tags.datadoghq.com/version": GIT_TAG, + "use_scale_launch_endpoint_network_policy": "true", } assert body["metadata"]["labels"] == expected_labels @@ -275,22 +299,23 @@ def _verify_custom_object_plurals(call_args_list, expected_plurals: List[str]) - @pytest.mark.asyncio -async def test_create_async_endpoint_has_correct_labels( +async def test_create_async_endpoint_has_correct_labels_and_dest( k8s_endpoint_resource_delegate, mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, - create_resources_request_async_custom: CreateOrUpdateResourcesRequest, - create_resources_request_async_tensorflow: CreateOrUpdateResourcesRequest, + mock_get_kubernetes_cluster_version, + create_resources_request_async_runnable_image: CreateOrUpdateResourcesRequest, ): for request in [ - create_resources_request_async_custom, - create_resources_request_async_tensorflow, + create_resources_request_async_runnable_image, ]: - await k8s_endpoint_resource_delegate.create_or_update_resources( + dest = await k8s_endpoint_resource_delegate.create_or_update_resources( request, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue" ) + assert dest == "my_queue" # Verify deployment labels create_deployment_call_args = mock_apps_client.create_namespaced_deployment.call_args @@ -328,6 +353,11 @@ async def test_create_async_endpoint_has_correct_labels( ) assert delete_custom_object_call_args_list == [] + # Verify PDB labels + create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args + pdb_body = create_pdb_call_args.kwargs["body"] + _verify_non_deployment_labels(pdb_body, request) + if build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.SYNC: assert create_custom_object_call_args_list == [] _verify_custom_object_plurals( @@ -339,20 +369,26 @@ async def test_create_async_endpoint_has_correct_labels( @pytest.mark.asyncio -async def test_create_streaming_endpoint_has_correct_labels( +async def test_create_streaming_endpoint_has_correct_labels_and_dest( k8s_endpoint_resource_delegate, mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, + mock_get_kubernetes_cluster_version, create_resources_request_streaming_runnable_image: CreateOrUpdateResourcesRequest, ): request = create_resources_request_streaming_runnable_image - await k8s_endpoint_resource_delegate.create_or_update_resources( + dest = await k8s_endpoint_resource_delegate.create_or_update_resources( request, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue", ) + service_name = mock_core_client.create_namespaced_service.call_args.kwargs["body"]["metadata"][ + "name" + ] + assert dest == service_name # Verify deployment labels create_deployment_call_args = mock_apps_client.create_namespaced_deployment.call_args @@ -369,6 +405,11 @@ async def test_create_streaming_endpoint_has_correct_labels( config_map_body = create_config_map_call_args.kwargs["body"] _verify_non_deployment_labels(config_map_body, request) + # Verify PDB labels + create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args + pdb_body = create_pdb_call_args.kwargs["body"] + _verify_non_deployment_labels(pdb_body, request) + # Verify HPA labels create_hpa_call_args = ( mock_autoscaling_client.create_namespaced_horizontal_pod_autoscaler.call_args @@ -385,7 +426,12 @@ async def test_create_streaming_endpoint_has_correct_labels( if optimize_costs: _verify_custom_object_plurals( call_args_list=create_custom_object_call_args_list, - expected_plurals=["verticalpodautoscalers"], + expected_plurals=["verticalpodautoscalers", "virtualservices", "destinationrules"], + ) + if build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.SYNC: + _verify_custom_object_plurals( + call_args_list=create_custom_object_call_args_list, + expected_plurals=["virtualservices", "destinationrules"], ) mock_custom_objects_client.reset_mock() @@ -400,24 +446,28 @@ async def test_create_streaming_endpoint_has_correct_labels( @pytest.mark.asyncio -async def test_create_sync_endpoint_has_correct_labels( +async def test_create_sync_endpoint_has_correct_labels_and_dest( k8s_endpoint_resource_delegate, mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, - create_resources_request_sync_pytorch: CreateOrUpdateResourcesRequest, - create_resources_request_sync_custom: CreateOrUpdateResourcesRequest, + mock_get_kubernetes_cluster_version, + create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest, ): for request in [ - create_resources_request_sync_pytorch, - create_resources_request_sync_custom, + create_resources_request_sync_runnable_image, ]: - await k8s_endpoint_resource_delegate.create_or_update_resources( + dest = await k8s_endpoint_resource_delegate.create_or_update_resources( request, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue,", ) + service_name = mock_core_client.create_namespaced_service.call_args.kwargs["body"][ + "metadata" + ]["name"] + assert dest == service_name # Verify deployment labels create_deployment_call_args = mock_apps_client.create_namespaced_deployment.call_args @@ -441,6 +491,11 @@ async def test_create_sync_endpoint_has_correct_labels( hpa_body = create_hpa_call_args.kwargs["body"] _verify_non_deployment_labels(hpa_body, request) + # Verify PDB labels + create_pdb_call_args = mock_policy_client.create_namespaced_pod_disruption_budget.call_args + pdb_body = create_pdb_call_args.kwargs["body"] + _verify_non_deployment_labels(pdb_body, request) + # Make sure that an VPA is created if optimize_costs is True. build_endpoint_request = request.build_endpoint_request optimize_costs = build_endpoint_request.optimize_costs @@ -450,13 +505,20 @@ async def test_create_sync_endpoint_has_correct_labels( if optimize_costs: _verify_custom_object_plurals( call_args_list=create_custom_object_call_args_list, - expected_plurals=["verticalpodautoscalers"], + expected_plurals=["verticalpodautoscalers", "virtualservices", "destinationrules"], + ) + if build_endpoint_request.model_endpoint_record.endpoint_type == ModelEndpointType.SYNC: + _verify_custom_object_plurals( + call_args_list=create_custom_object_call_args_list, + expected_plurals=["virtualservices", "destinationrules"], ) mock_custom_objects_client.reset_mock() # Make sure that an VPA is created if optimize_costs is True. - optimize_costs = create_resources_request_sync_pytorch.build_endpoint_request.optimize_costs + optimize_costs = ( + create_resources_request_sync_runnable_image.build_endpoint_request.optimize_costs + ) create_vpa_call_args = mock_custom_objects_client.create_namespaced_custom_objects.call_args if optimize_costs: assert create_vpa_call_args is not None @@ -470,11 +532,13 @@ async def test_create_sync_endpoint_has_correct_k8s_service_type( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, - create_resources_request_sync_pytorch: CreateOrUpdateResourcesRequest, + mock_get_kubernetes_cluster_version, + create_resources_request_sync_runnable_image: CreateOrUpdateResourcesRequest, ): await k8s_endpoint_resource_delegate.create_or_update_resources( - create_resources_request_sync_pytorch, + create_resources_request_sync_runnable_image, sqs_queue_name="my_queue", sqs_queue_url="https://my_queue", ) @@ -486,6 +550,48 @@ async def test_create_sync_endpoint_has_correct_k8s_service_type( assert service_body["spec"] is not None +@pytest.mark.asyncio +async def test_create_multinode_endpoint_creates_lws_and_correct_dest( + k8s_endpoint_resource_delegate, + mock_apps_client, + mock_core_client, + mock_autoscaling_client, + mock_policy_client, + mock_custom_objects_client, + mock_get_kubernetes_cluster_version, + create_resources_request_streaming_runnable_image: CreateOrUpdateResourcesRequest, + model_bundle_5: ModelBundle, +): + # Patch model bundle so that it supports multinode + model_bundle_5.flavor.worker_env = {"fake_env": "fake_value"} + model_bundle_5.flavor.worker_command = ["fake_command"] + create_resources_request_streaming_runnable_image.build_endpoint_request.model_endpoint_record.current_model_bundle = ( + model_bundle_5 + ) + create_resources_request_streaming_runnable_image.build_endpoint_request.model_endpoint_record.endpoint_type = ( + ModelEndpointType.STREAMING + ) + + create_resources_request_streaming_runnable_image.build_endpoint_request.nodes_per_worker = 2 + dest = await k8s_endpoint_resource_delegate.create_or_update_resources( + create_resources_request_streaming_runnable_image, + sqs_queue_name="my_queue", + sqs_queue_url="https://my_queue", + ) + service_name = mock_core_client.create_namespaced_service.call_args.kwargs["body"]["metadata"][ + "name" + ] + assert dest == service_name + # Verify call to custom objects client with LWS is made + create_custom_objects_call_args_list = ( + mock_custom_objects_client.create_namespaced_custom_object.call_args_list + ) + assert any( + call_args.kwargs["group"] == "leaderworkerset.x-k8s.io" + for call_args in create_custom_objects_call_args_list + ) + + @pytest.mark.asyncio async def test_create_endpoint_raises_k8s_endpoint_resource_delegate( k8s_endpoint_resource_delegate, @@ -523,8 +629,11 @@ async def test_get_resources_async_success( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, ): + # Pretend that LWS get gives an ApiException since it doesn't exist + mock_custom_objects_client.get_namespaced_custom_object = AsyncMock(side_effect=ApiException) k8s_endpoint_resource_delegate.__setattr__( "_get_common_endpoint_params", Mock( @@ -550,7 +659,7 @@ async def test_get_resources_async_success( Mock(return_value=FakeK8sDeploymentContainer(env=[])), ) k8s_endpoint_resource_delegate.__setattr__( - "_get_llm_engine_container", + "_get_launch_container", Mock( return_value=FakeK8sDeploymentContainer( env=[FakeK8sEnvVar(name="PREWARM", value="true")] @@ -582,8 +691,11 @@ async def test_get_resources_sync_success( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, ): + # Pretend that LWS get and keda get give an ApiException + mock_custom_objects_client.get_namespaced_custom_object = AsyncMock(side_effect=ApiException) k8s_endpoint_resource_delegate.__setattr__( "_get_common_endpoint_params", Mock( @@ -608,8 +720,7 @@ async def test_get_resources_sync_success( "_get_main_container", Mock(return_value=FakeK8sDeploymentContainer(env=[])) ) k8s_endpoint_resource_delegate.__setattr__( - "_get_llm_engine_container", - Mock(return_value=FakeK8sDeploymentContainer(env=[])), + "_get_launch_container", Mock(return_value=FakeK8sDeploymentContainer(env=[])) ) k8s_endpoint_resource_delegate.__setattr__( "_translate_k8s_config_maps_to_user_config_data", @@ -630,6 +741,40 @@ async def test_get_resources_sync_success( assert infra_state +@pytest.mark.asyncio +async def test_get_resources_multinode_success( + k8s_endpoint_resource_delegate, + mock_apps_client, + mock_core_client, + mock_autoscaling_client, + mock_policy_client, + mock_custom_objects_client, +): + k8s_endpoint_resource_delegate.__setattr__( + "_translate_k8s_config_maps_to_user_config_data", + Mock( + return_value=ModelEndpointUserConfigState( + app_config=None, + endpoint_config=ModelEndpointConfig( + endpoint_name="test_endpoint", + bundle_name="test_bundle", + post_inference_hooks=["callback"], + ), + ) + ), + ) + + mock_custom_objects_client.get_namespaced_custom_object = AsyncMock( + return_value=EXAMPLE_LWS_CONFIG + ) + + infra_state = await k8s_endpoint_resource_delegate.get_resources( + endpoint_id="", deployment_name="", endpoint_type=ModelEndpointType.STREAMING + ) + assert infra_state + assert infra_state.resource_state.nodes_per_worker == 2 + + @pytest.mark.asyncio async def test_delete_resources_invalid_endpoint_type_returns_false( k8s_endpoint_resource_delegate, @@ -646,6 +791,7 @@ async def test_delete_resources_async_success( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, ): deleted = await k8s_endpoint_resource_delegate.delete_resources( @@ -660,9 +806,89 @@ async def test_delete_resources_sync_success( mock_apps_client, mock_core_client, mock_autoscaling_client, + mock_policy_client, mock_custom_objects_client, ): deleted = await k8s_endpoint_resource_delegate.delete_resources( endpoint_id="", deployment_name="", endpoint_type=ModelEndpointType.SYNC ) assert deleted + + +@pytest.mark.asyncio +async def test_delete_resources_multinode_success( + k8s_endpoint_resource_delegate, + mock_apps_client, + mock_core_client, + mock_autoscaling_client, + mock_policy_client, + mock_custom_objects_client, +): + mock_custom_objects_client.get_namespaced_custom_object = AsyncMock( + return_value=EXAMPLE_LWS_CONFIG + ) + mock_custom_objects_client.delete_namespaced_custom_object = AsyncMock() + deleted = await k8s_endpoint_resource_delegate.delete_resources( + endpoint_id="", deployment_name="", endpoint_type=ModelEndpointType.STREAMING + ) + assert deleted + delete_called_for_lws = False + for call_args in mock_custom_objects_client.delete_namespaced_custom_object.call_args_list: + # 'group' is kwargs in delete_namespaced_custom_object + if call_args[1]["group"] == "leaderworkerset.x-k8s.io": + delete_called_for_lws = True + break + assert delete_called_for_lws + + +@pytest.mark.asyncio +async def test_create_pdb( + k8s_endpoint_resource_delegate, + mock_policy_client, +): + # Mock the necessary objects and functions + pdb = { + "metadata": {"name": "test-pdb", "namespace": "test-namespace"}, + "spec": {"maxUnavailable": "50%"}, + } + name = "test-pdb" + + # Test successful creation + await k8s_endpoint_resource_delegate._create_pdb(pdb, name) + + mock_policy_client.create_namespaced_pod_disruption_budget.assert_called_once_with( + namespace=hmi_config.endpoint_namespace, + body=pdb, + ) + + # Test creation when PDB already exists + mock_policy_client.create_namespaced_pod_disruption_budget.side_effect = ApiException( + status=409 + ) + + existing_pdb = Mock() + existing_pdb.metadata.resource_version = "123" + mock_policy_client.read_namespaced_pod_disruption_budget.return_value = existing_pdb + + await k8s_endpoint_resource_delegate._create_pdb(pdb, name) + + mock_policy_client.read_namespaced_pod_disruption_budget.assert_called_once_with( + name=name, namespace=hmi_config.endpoint_namespace + ) + + expected_replace_pdb = pdb.copy() + expected_replace_pdb["metadata"]["resourceVersion"] = "123" + + mock_policy_client.replace_namespaced_pod_disruption_budget.assert_called_once_with( + name=name, + namespace=hmi_config.endpoint_namespace, + body=expected_replace_pdb, + ) + + # Test creation with other API exception + mock_policy_client.create_namespaced_pod_disruption_budget.side_effect = ApiException( + status=500 + ) + + with pytest.raises(ApiException): + await k8s_endpoint_resource_delegate._create_pdb(pdb, name) diff --git a/server/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py b/model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py similarity index 85% rename from server/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py rename to model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py index 952712a5..ae00ac43 100644 --- a/server/tests/unit/infra/gateways/resources/test_live_sqs_endpoint_resource_delegate.py +++ b/model-engine/tests/unit/infra/gateways/resources/test_sqs_queue_endpoint_resource_delegate.py @@ -4,14 +4,14 @@ import botocore.exceptions import pytest -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from llm_engine_server.domain.entities import ModelEndpointRecord -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, +from model_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest +from model_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.domain.exceptions import EndpointResourceInfraException +from model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate import ( + SQSQueueEndpointResourceDelegate, ) -MODULE_PATH = "llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate" +MODULE_PATH = "model_engine_server.infra.gateways.resources.sqs_queue_endpoint_resource_delegate" EXPECTED_QUEUE_POLICY = """ { @@ -25,7 +25,7 @@ "AWS": "arn:aws:iam::000000000000:root" }, "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:llm-engine-endpoint-id-test_model_endpoint_id_3" + "Resource": "arn:aws:sqs:us-west-2:000000000000:launch-endpoint-id-test_model_endpoint_id_3" }, { "Effect": "Allow", @@ -33,29 +33,21 @@ "AWS": "arn:aws:iam::000000000000:role/default" }, "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:llm-engine-endpoint-id-test_model_endpoint_id_3" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/ml_llm_engine" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:llm-engine-endpoint-id-test_model_endpoint_id_3" + "Resource": "arn:aws:sqs:us-west-2:000000000000:launch-endpoint-id-test_model_endpoint_id_3" } ] } """ EXPECTED_QUEUE_TAGS = { - "infra.scale.com/product": "MLInfraLLMEngineSQS", + "infra.scale.com/product": "MLInfraLaunchSQS", "infra.scale.com/team": "test_team", "infra.scale.com/contact": "yi.xu@scale.com", "infra.scale.com/customer": "AllCustomers", "infra.scale.com/financialOwner": "yi.xu@scale.com", - "Spellbook-Serve-Endpoint-Id": "test_model_endpoint_id_3", - "Spellbook-Serve-Endpoint-Name": "test_model_endpoint_name_3", - "Spellbook-Serve-Endpoint-Created-By": "test_user_id", + "Launch-Endpoint-Id": "test_model_endpoint_id_3", + "Launch-Endpoint-Name": "test_model_endpoint_name_3", + "Launch-Endpoint-Created-By": "test_user_id", } @@ -75,7 +67,7 @@ def _get_fake_botocore_exception(): @pytest.fixture def mock_create_async_sqs_client_create_queue(): create_queue_response = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-test_model_endpoint_id_3", + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-test_model_endpoint_id_3", "ResponseMetadata": { "RequestId": "9c05b1cc-d806-5cbd-bd4a-ea339c90e25f", "HTTPStatusCode": 200, @@ -108,7 +100,7 @@ def mock_create_async_sqs_client_create_queue(): @pytest.fixture def mock_create_async_sqs_client_get_queue_url(): get_queue_response = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-test_model_endpoint_id_3", + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-test_model_endpoint_id_3", } mock_sqs_client_session_val = AsyncMock() @@ -179,7 +171,7 @@ def mock_create_async_sqs_client_delete_queue(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } delete_response = { @@ -213,7 +205,7 @@ def mock_create_async_sqs_client_delete_queue_returns_non_200(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } delete_response = { @@ -247,7 +239,7 @@ def mock_create_async_sqs_client_delete_queue_throws_exception(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } mock_sqs_client_session_val.delete_queue = AsyncMock(side_effect=_get_fake_botocore_exception()) @@ -268,12 +260,12 @@ def mock_create_async_sqs_client_get_queue_attributes(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } get_queue_attributes_response = { "Attributes": { - "QueueArn": "arn:aws:sqs:us-west-2:000000000000:llm-engine-endpoint-id-model_endpoint_id_1", + "QueueArn": "arn:aws:sqs:us-west-2:000000000000:launch-endpoint-id-model_endpoint_id_1", "ApproximateNumberOfMessages": "0", "ApproximateNumberOfMessagesNotVisible": "0", "ApproximateNumberOfMessagesDelayed": "0", @@ -326,7 +318,7 @@ def mock_create_async_sqs_client_get_queue_attributes_queue_throws_exception(): mock_sqs_client_session_val.get_queue_url = AsyncMock() mock_sqs_client_session_val.get_queue_url.return_value = { - "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/llm-engine-endpoint-id-model_endpoint_id_1" + "QueueUrl": "https://us-west-2.queue.amazonaws.com/000000000000/launch-endpoint-id-model_endpoint_id_1" } mock_sqs_client_session_val.get_queue_attributes = AsyncMock( @@ -348,7 +340,7 @@ async def test_sqs_create_or_update_resources_endpoint_exists( build_endpoint_request_async_custom: BuildEndpointRequest, mock_create_async_sqs_client_get_queue_url, ): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record queue_name, queue_url = await delegate.create_queue_if_not_exists( endpoint_id=endpoint_record.id, @@ -360,7 +352,7 @@ async def test_sqs_create_or_update_resources_endpoint_exists( mock_create_async_sqs_client_get_queue_url.__aenter__.assert_called_once() expected_get_queue_url_args: Dict[str, Any] = { - "QueueName": "llm-engine-endpoint-id-test_model_endpoint_id_3", + "QueueName": "launch-endpoint-id-test_model_endpoint_id_3", } actual_get_queue_kwargs = ( mock_create_async_sqs_client_get_queue_url.__aenter__.return_value.get_queue_url.call_args.kwargs @@ -376,7 +368,7 @@ async def test_sqs_create_or_update_resources( build_endpoint_request_async_custom: BuildEndpointRequest, mock_create_async_sqs_client_create_queue, ): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record queue_name, queue_url = await delegate.create_queue_if_not_exists( endpoint_id=endpoint_record.id, @@ -388,7 +380,7 @@ async def test_sqs_create_or_update_resources( mock_create_async_sqs_client_create_queue.__aenter__.assert_called_once() expected_create_queue_args: Dict[str, Any] = { - "QueueName": "llm-engine-endpoint-id-test_model_endpoint_id_3", + "QueueName": "launch-endpoint-id-test_model_endpoint_id_3", "Attributes": { "VisibilityTimeout": "3600", "Policy": EXPECTED_QUEUE_POLICY, @@ -416,7 +408,7 @@ async def test_sqs_create_or_update_resources_throws_exception( build_endpoint_request_async_custom: BuildEndpointRequest, mock_create_async_sqs_client_create_queue_throws_exception, ): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record with pytest.raises(EndpointResourceInfraException): await delegate.create_queue_if_not_exists( @@ -432,7 +424,7 @@ async def test_sqs_create_or_update_resources_non_200( build_endpoint_request_async_custom: BuildEndpointRequest, mock_create_async_sqs_client_create_queue_returns_non_200, ): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") endpoint_record: ModelEndpointRecord = build_endpoint_request_async_custom.model_endpoint_record with pytest.raises(EndpointResourceInfraException): await delegate.create_queue_if_not_exists( @@ -445,18 +437,18 @@ async def test_sqs_create_or_update_resources_non_200( @pytest.mark.asyncio async def test_sqs_delete_resources(mock_create_async_sqs_client_delete_queue): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.delete_queue(endpoint_id="model_endpoint_id_1") mock_create_async_sqs_client_delete_queue.__aenter__.assert_called_once() mock_create_async_sqs_client_delete_queue.__aenter__.return_value.get_queue_url.assert_called_once_with( - QueueName="llm-engine-endpoint-id-model_endpoint_id_1" + QueueName="launch-endpoint-id-model_endpoint_id_1" ) delete_call_kwargs = ( mock_create_async_sqs_client_delete_queue.__aenter__.return_value.delete_queue.call_args.kwargs ) - assert delete_call_kwargs["QueueUrl"].endswith("llm-engine-endpoint-id-model_endpoint_id_1") + assert delete_call_kwargs["QueueUrl"].endswith("launch-endpoint-id-model_endpoint_id_1") @pytest.mark.asyncio @@ -464,7 +456,7 @@ async def test_sqs_delete_resources_throws_exception( mock_create_async_sqs_client_delete_queue_throws_exception, ): with pytest.raises(EndpointResourceInfraException): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.delete_queue(endpoint_id="model_endpoint_id_1") @@ -473,30 +465,28 @@ async def test_sqs_delete_resources_non_200( mock_create_async_sqs_client_delete_queue_returns_non_200, ): with pytest.raises(EndpointResourceInfraException): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.delete_queue(endpoint_id="model_endpoint_id_1") @pytest.mark.asyncio -async def test_sqs_get_queue_attributes( - mock_create_async_sqs_client_get_queue_attributes, -): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") +async def test_sqs_get_queue_attributes(mock_create_async_sqs_client_get_queue_attributes): + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") response = await delegate.get_queue_attributes(endpoint_id="model_endpoint_id_1") mock_create_async_sqs_client_get_queue_attributes.__aenter__.assert_called_once() mock_create_async_sqs_client_get_queue_attributes.__aenter__.return_value.get_queue_url.assert_called_once_with( - QueueName="llm-engine-endpoint-id-model_endpoint_id_1" + QueueName="launch-endpoint-id-model_endpoint_id_1" ) get_queue_attributes_call_kwargs = ( mock_create_async_sqs_client_get_queue_attributes.__aenter__.return_value.get_queue_attributes.call_args.kwargs ) assert get_queue_attributes_call_kwargs["QueueUrl"].endswith( - "llm-engine-endpoint-id-model_endpoint_id_1" + "launch-endpoint-id-model_endpoint_id_1" ) - assert response["Attributes"]["QueueArn"].endswith("llm-engine-endpoint-id-model_endpoint_id_1") + assert response["Attributes"]["QueueArn"].endswith("launch-endpoint-id-model_endpoint_id_1") @pytest.mark.asyncio @@ -504,7 +494,7 @@ async def test_sqs_get_queue_attributes_queue_not_found( mock_create_async_sqs_client_get_queue_attributes_queue_not_found, ): with pytest.raises(EndpointResourceInfraException): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.get_queue_attributes(endpoint_id="model_endpoint_id_1") @@ -513,5 +503,5 @@ async def test_sqs_get_queue_attributes_queue_throws_exception( mock_create_async_sqs_client_get_queue_attributes_queue_throws_exception, ): with pytest.raises(EndpointResourceInfraException): - delegate = LiveSQSEndpointResourceDelegate(sqs_profile="foobar") + delegate = SQSQueueEndpointResourceDelegate(sqs_profile="foobar") await delegate.get_queue_attributes(endpoint_id="model_endpoint_id_1") diff --git a/model-engine/tests/unit/infra/gateways/test_datadog_inference_monitoring_metrics_gateway.py b/model-engine/tests/unit/infra/gateways/test_datadog_inference_monitoring_metrics_gateway.py new file mode 100644 index 00000000..cb99d9b7 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_datadog_inference_monitoring_metrics_gateway.py @@ -0,0 +1,39 @@ +from unittest.mock import Mock + +import pytest +from datadog import statsd +from model_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( + DatadogInferenceMonitoringMetricsGateway, +) + + +@pytest.fixture(autouse=True) +def mock_statsd(): + # https://github.com/DataDog/datadogpy/issues/183 for how dd mocks statsd + statsd.socket = Mock() + # also mock the methods we use or may use, there might be more + statsd.gauge = Mock() + statsd.increment = Mock() + statsd.decrement = Mock() + statsd.histogram = Mock() + statsd.distribution = Mock() + + +@pytest.fixture +def datadog_inference_monitoring_metrics_gateway(): + return DatadogInferenceMonitoringMetricsGateway() + + +def test_datadog_inference_monitoring_metrics_gateway_batch_completion_metrics( + datadog_inference_monitoring_metrics_gateway, +): + model = "test_model" + use_tool = True + num_prompt_tokens = 100 + num_completion_tokens = 200 + is_finetuned = True + datadog_inference_monitoring_metrics_gateway.emit_batch_completions_metric( + model, use_tool, num_prompt_tokens, num_completion_tokens, is_finetuned + ) + statsd.increment.assert_called() + statsd.increment.reset_mock() diff --git a/model-engine/tests/unit/infra/gateways/test_datadog_monitoring_metrics_gateway.py b/model-engine/tests/unit/infra/gateways/test_datadog_monitoring_metrics_gateway.py new file mode 100644 index 00000000..e3e295a6 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_datadog_monitoring_metrics_gateway.py @@ -0,0 +1,106 @@ +from unittest.mock import Mock + +import pytest +from datadog import statsd +from model_engine_server.common.dtos.llms import TokenUsage +from model_engine_server.core.auth.authentication_repository import User +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MetricMetadata +from model_engine_server.infra.gateways import DatadogMonitoringMetricsGateway + + +@pytest.fixture(autouse=True) +def mock_statsd(): + # https://github.com/DataDog/datadogpy/issues/183 for how dd mocks statsd + statsd.socket = Mock() + # also mock the methods we use or may use, there might be more + statsd.gauge = Mock() + statsd.increment = Mock() + statsd.decrement = Mock() + statsd.histogram = Mock() + statsd.distribution = Mock() + + +@pytest.fixture +def sync_token_count(): + return TokenUsage( + num_prompt_tokens=100, + num_completion_tokens=200, + total_duration=30, + ) + + +@pytest.fixture +def streaming_token_count(): + return TokenUsage( + num_prompt_tokens=100, + num_completion_tokens=200, + total_duration=30, + time_to_first_token=5, + ) + + +@pytest.fixture +def datadog_monitoring_metrics_gateway(): + gateway = DatadogMonitoringMetricsGateway(prefix="model_engine_unit_test") + return gateway + + +def test_datadog_monitoring_metrics_gateway_build_metrics(datadog_monitoring_metrics_gateway): + datadog_monitoring_metrics_gateway.emit_attempted_build_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_successful_build_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_build_time_metric(300) + statsd.distribution.assert_called_once() + statsd.distribution.reset_mock() + datadog_monitoring_metrics_gateway.emit_image_build_cache_hit_metric("test_image") + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_image_build_cache_miss_metric("test_image_2") + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_docker_failed_build_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + + +def test_datadog_monitoring_metrics_gateway_db_metrics(datadog_monitoring_metrics_gateway): + datadog_monitoring_metrics_gateway.emit_database_cache_hit_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + datadog_monitoring_metrics_gateway.emit_database_cache_miss_metric() + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + + +def test_datadog_monitoring_metrics_gateway_route_call_metrics(datadog_monitoring_metrics_gateway): + metadata = MetricMetadata( + user=User(user_id="test_user", team_id="test_team", email="test_email"), + model_name="test_model", + ) + datadog_monitoring_metrics_gateway.emit_route_call_metric("test_route", metadata) + statsd.increment.assert_called_once() + statsd.increment.reset_mock() + + +def test_datadog_monitoring_metrics_gateway_token_count_metrics( + datadog_monitoring_metrics_gateway, sync_token_count, streaming_token_count +): + metadata = MetricMetadata( + user=User(user_id="test_user", team_id="test_team", email="test_email"), + model_name="test_model", + ) + datadog_monitoring_metrics_gateway.emit_token_count_metrics(sync_token_count, metadata) + statsd.increment.assert_called() + statsd.increment.reset_mock() + statsd.histogram.assert_called() + statsd.histogram.reset_mock() + datadog_monitoring_metrics_gateway.emit_token_count_metrics(streaming_token_count, metadata) + statsd.increment.assert_called() + statsd.increment.reset_mock() + statsd.histogram.assert_called() + statsd.histogram.reset_mock() + statsd.distribution.assert_called() + statsd.distribution.reset_mock() diff --git a/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py b/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py new file mode 100644 index 00000000..d4902b29 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_firehose_streaming_storage_gateway.py @@ -0,0 +1,91 @@ +from unittest import mock + +import pytest +from model_engine_server.domain.exceptions import StreamPutException +from model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway import ( + FirehoseStreamingStorageGateway, +) + +stream_name = "fake-stream" + +return_value = { + "RecordId": "fake-record-id", + "Encrypted": False, + "ResponseMetadata": {"HTTPStatusCode": 200}, +} + + +@pytest.fixture +def streaming_storage_gateway(): + gateway = FirehoseStreamingStorageGateway() + return gateway + + +@pytest.fixture +def fake_record(): + return {"RESPONSE_BODY": {"task_id": "fake-task-id"}} + + +def mock_sts_client(*args, **kwargs): + mock_client = mock.Mock() + mock_client.assume_role.return_value = { + "Credentials": { + "AccessKeyId": "fake-access-key-id", + "SecretAccessKey": "fake-secret-access-key", + "SessionToken": "fake-session-token", + } + } + return mock_client + + +def mock_firehose_client(*args, **kwargs): + mock_client = mock.Mock() + mock_client.put_record.return_value = return_value + return mock_client + + +def mock_firehose_client_with_exception(*args, **kwargs): + mock_client = mock.Mock() + mock_client.put_record.return_value = { + "RecordId": "fake-record-id", + "Encrypted": False, + "ResponseMetadata": {"HTTPStatusCode": 500}, + } + return mock_client + + +mock_sts_session = mock.Mock() +mock_sts_session.client.return_value = mock_sts_client() + + +mock_firehose_session = mock.Mock() +mock_firehose_session.client.return_value = mock_firehose_client() + + +mock_session_with_exception = mock.Mock() +mock_session_with_exception.client.return_value = mock_firehose_client_with_exception() + + +def test_firehose_streaming_storage_gateway_put_record(streaming_storage_gateway, fake_record): + with mock.patch( + "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.client", + mock_sts_client, + ), mock.patch( + "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.Session", + side_effect=[mock_sts_session, mock_firehose_session], + ): + assert streaming_storage_gateway.put_record(stream_name, fake_record) is return_value + + +def test_firehose_streaming_storage_gateway_put_record_with_exception( + streaming_storage_gateway, fake_record +): + with mock.patch( + "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.client", + mock_sts_client, + ), mock.patch( + "model_engine_server.inference.infra.gateways.firehose_streaming_storage_gateway.boto3.Session", + side_effect=[mock_sts_session, mock_session_with_exception], + ): + with pytest.raises(StreamPutException): + streaming_storage_gateway.put_record(stream_name, fake_record) diff --git a/server/tests/unit/infra/gateways/test_k8s_resource_parser.py b/model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py similarity index 91% rename from server/tests/unit/infra/gateways/test_k8s_resource_parser.py rename to model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py index 91741b7a..dd3462d5 100644 --- a/server/tests/unit/infra/gateways/test_k8s_resource_parser.py +++ b/model-engine/tests/unit/infra/gateways/test_k8s_resource_parser.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.infra.gateways.k8s_resource_parser import ( +from model_engine_server.infra.gateways.k8s_resource_parser import ( get_per_worker_value_from_target_concurrency, get_target_concurrency_from_per_worker_value, parse_cpu_request, @@ -99,13 +99,13 @@ def test_parse_mem_request(): @pytest.mark.parametrize( "input_value", [ - "1", - "1.5", - "500m", - "5500m", + ("1", "1"), + ("1.5", "2"), + ("500m", "1"), + ("5500m", "6"), ], ) def test_get_target_concurrency_to_per_worker_value(input_value): assert get_target_concurrency_from_per_worker_value( - parse_cpu_request(str(get_per_worker_value_from_target_concurrency(input_value))) - ) == parse_cpu_request(input_value) + parse_cpu_request(str(get_per_worker_value_from_target_concurrency(input_value[0]))) + ) == parse_cpu_request(input_value[1]) diff --git a/server/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py similarity index 63% rename from server/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py index ce2309ef..8140b1c2 100644 --- a/server/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_async_model_inference_gateway.py @@ -1,9 +1,10 @@ import json +from datetime import datetime, timedelta from typing import Any import pytest -from llm_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, TaskStatus -from llm_engine_server.infra.gateways import LiveAsyncModelEndpointInferenceGateway +from model_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, TaskStatus +from model_engine_server.infra.gateways import LiveAsyncModelEndpointInferenceGateway @pytest.fixture @@ -22,10 +23,11 @@ def test_task_create_get_url( task_id = create_response.task_id task_queue_gateway: Any = fake_live_async_model_inference_gateway.task_queue_gateway assert len(task_queue_gateway.queue) == 1 - assert task_queue_gateway.queue[task_id]["args"] == [ - endpoint_predict_request_1[0].dict(), - endpoint_predict_request_1[0].return_pickled, - ] + assert task_queue_gateway.queue[task_id]["args"][0] == endpoint_predict_request_1[0].dict() + assert (datetime.now() - task_queue_gateway.queue[task_id]["args"][1]) < timedelta(seconds=1) + assert ( + task_queue_gateway.queue[task_id]["args"][2] == endpoint_predict_request_1[0].return_pickled + ) get_response_1 = fake_live_async_model_inference_gateway.get_task(task_id) assert get_response_1 == GetAsyncTaskV1Response(task_id=task_id, status=TaskStatus.PENDING) @@ -49,17 +51,19 @@ def test_task_create_get_args_callback( task_id = create_response.task_id task_queue_gateway: Any = fake_live_async_model_inference_gateway.task_queue_gateway assert len(task_queue_gateway.queue) == 1 - assert task_queue_gateway.queue[task_id]["args"] == [ - { - "args": endpoint_predict_request_2[0].args.__root__, - "url": None, - "cloudpickle": None, - "callback_auth": json.loads(endpoint_predict_request_2[0].callback_auth.json()), - "callback_url": endpoint_predict_request_2[0].callback_url, - "return_pickled": endpoint_predict_request_2[0].return_pickled, - }, - endpoint_predict_request_2[0].return_pickled, - ] + assert task_queue_gateway.queue[task_id]["args"][0] == { + "args": endpoint_predict_request_2[0].args.root, + "url": None, + "cloudpickle": None, + "callback_auth": json.loads(endpoint_predict_request_2[0].callback_auth.json()), + "callback_url": endpoint_predict_request_2[0].callback_url, + "return_pickled": endpoint_predict_request_2[0].return_pickled, + "destination_path": None, + } + assert (datetime.now() - task_queue_gateway.queue[task_id]["args"][1]) < timedelta(seconds=1) + assert ( + task_queue_gateway.queue[task_id]["args"][2] == endpoint_predict_request_2[0].return_pickled + ) get_response_1 = fake_live_async_model_inference_gateway.get_task(task_id) assert get_response_1 == GetAsyncTaskV1Response(task_id=task_id, status=TaskStatus.PENDING) diff --git a/server/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py similarity index 85% rename from server/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py index 53716d53..4112ac8b 100644 --- a/server/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_batch_job_progress_gateway.py @@ -1,5 +1,5 @@ -from llm_engine_server.domain.entities import BatchJobProgress -from llm_engine_server.infra.gateways import LiveBatchJobProgressGateway +from model_engine_server.domain.entities import BatchJobProgress +from model_engine_server.infra.gateways import LiveBatchJobProgressGateway def test_get_progress_empty(test_api_key: str, fake_filesystem_gateway): @@ -34,4 +34,4 @@ def test_update_progress(test_api_key: str, fake_filesystem_gateway): progress=BatchJobProgress(num_tasks_pending=4, num_tasks_completed=5), ) handle = fake_filesystem_gateway.mock_open() - handle.write.assert_called_once_with('{"num_tasks_pending": 4, "num_tasks_completed": 5}') + handle.write.assert_called_once_with('{"num_tasks_pending":4,"num_tasks_completed":5}') diff --git a/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py new file mode 100644 index 00000000..b792b3d4 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py @@ -0,0 +1,255 @@ +from unittest.mock import AsyncMock, patch + +import pytest +from model_engine_server.domain.entities import BatchJobStatus +from model_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( + K8sEnvDict, + LiveDockerImageBatchJobGateway, + _add_list_values, + _check_batch_job_id_valid, + _get_job_id, +) +from tests.unit.infra.gateways.k8s_fake_objects import ( + FakeK8sV1Job, + FakeK8sV1JobList, + FakeK8sV1JobStatus, + FakeK8sV1ObjectMeta, + FakeK8sV1Pod, + FakeK8sV1PodList, + FakeK8sV1PodStatus, +) + +MODULE_PATH = "model_engine_server.infra.gateways.live_docker_image_batch_job_gateway" + + +@pytest.fixture +def mock_core_client(): + mock_client = AsyncMock() + with patch( + f"{MODULE_PATH}.get_kubernetes_core_client", + return_value=mock_client, + ): + yield mock_client + + +@pytest.fixture +def mock_batch_client(): + mock_client = AsyncMock() + with patch( + f"{MODULE_PATH}.get_kubernetes_batch_client", + return_value=mock_client, + ): + yield mock_client + + +@pytest.fixture +def docker_image_batch_job_gateway(): + gateway = LiveDockerImageBatchJobGateway() + return gateway + + +@pytest.mark.parametrize( + "active, succeeded, failed, pod_phase, pod_exists, expected_status", + [ + [1, 0, 0, "Running", True, BatchJobStatus.RUNNING], + [0, 1, 0, "Succeeded", True, BatchJobStatus.SUCCESS], + [0, 0, 1, "Failed", True, BatchJobStatus.FAILURE], + [1, 0, 0, "Pending", True, BatchJobStatus.PENDING], + [0, 0, 0, "Pending", False, BatchJobStatus.PENDING], + ], +) +@pytest.mark.asyncio +async def test_get_docker_image_batch_job_phase( + active, + succeeded, + failed, + pod_phase, + pod_exists, + expected_status, + docker_image_batch_job_gateway, + mock_core_client, + mock_batch_client, +): + if pod_exists: + pod_items = [ + FakeK8sV1Pod( + metadata=FakeK8sV1ObjectMeta( + labels={ + "job-name": "job-name", + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id", + } + ), + status=FakeK8sV1PodStatus( + phase=pod_phase, + ), + ) + ] + else: + pod_items = [] + + mock_core_client.list_namespaced_pod.return_value = FakeK8sV1PodList(items=pod_items) + mock_batch_client.list_namespaced_job.return_value = FakeK8sV1JobList( + items=[ + FakeK8sV1Job( + metadata=FakeK8sV1ObjectMeta( + name="job-name", + labels={ + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id", + }, + ), + status=FakeK8sV1JobStatus( + active=active, + succeeded=succeeded, + failed=failed, + ), + ) + ] + ) + + job = await docker_image_batch_job_gateway.get_docker_image_batch_job("launch_job_id") + assert job is not None + assert job.status == expected_status + + +@pytest.mark.asyncio +async def test_list_docker_image_batch_jobs( + docker_image_batch_job_gateway, + mock_core_client, + mock_batch_client, +): + mock_core_client.list_namespaced_pod.return_value = FakeK8sV1PodList( + items=[ + FakeK8sV1Pod( + metadata=FakeK8sV1ObjectMeta( + labels={ + "job-name": "job-name", + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id", + } + ), + status=FakeK8sV1PodStatus( + phase="Running", + ), + ), + FakeK8sV1Pod( + metadata=FakeK8sV1ObjectMeta( + labels={ + "job-name": "job-name2", + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id2", + } + ), + status=FakeK8sV1PodStatus( + phase="Succeeded", + ), + ), + ] + ) + mock_batch_client.list_namespaced_job.return_value = FakeK8sV1JobList( + items=[ + FakeK8sV1Job( + metadata=FakeK8sV1ObjectMeta( + name="job-name", + labels={ + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id", + }, + ), + status=FakeK8sV1JobStatus( + active=1, + succeeded=0, + failed=0, + ), + ), + FakeK8sV1Job( + metadata=FakeK8sV1ObjectMeta( + name="job-name2", + labels={ + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id2", + }, + ), + status=FakeK8sV1JobStatus( + active=0, + succeeded=1, + failed=0, + ), + ), + FakeK8sV1Job( + metadata=FakeK8sV1ObjectMeta( + name="job-name3", + labels={ + "owner": "owner", + "created_by": "created_by", + "trigger_id": "trigger_id", + "launch_job_id": "launch_job_id3", + }, + ), + status=FakeK8sV1JobStatus( + active=0, + succeeded=0, + failed=0, + ), + ), + ] + ) + + jobs = await docker_image_batch_job_gateway.list_docker_image_batch_jobs(owner="owner") + assert len(jobs) == 3 + job_ids_to_phases = {job.id: job.status for job in jobs} + assert job_ids_to_phases["launch_job_id"] == BatchJobStatus.RUNNING + assert job_ids_to_phases["launch_job_id2"] == BatchJobStatus.SUCCESS + assert job_ids_to_phases["launch_job_id3"] == BatchJobStatus.PENDING + + +# Small function functionality tests +def test_valid_job_ids_are_valid(): + for _ in range(20): + # _get_job_id() is nondeterministic + job_id = _get_job_id() + assert _check_batch_job_id_valid(job_id), f"job_id {job_id} apparently isn't valid" + + +def test_invalid_job_ids_are_invalid(): + assert not _check_batch_job_id_valid("spaces fail") + assert not _check_batch_job_id_valid("punctuation'") + assert not _check_batch_job_id_valid(".") + + +# test the adding list values +def test_add_list_values(): + default_values = [ + K8sEnvDict(name="default1", value="val1"), + K8sEnvDict(name="default2", value="val2"), + K8sEnvDict(name="default3", value="val3"), + ] + override_values = [ + K8sEnvDict(name="default1", value="override0"), + K8sEnvDict(name="override1", value="override1"), + K8sEnvDict(name="override2", value="override2"), + ] + expected_values = [ + K8sEnvDict(name="default1", value="val1"), + K8sEnvDict(name="default2", value="val2"), + K8sEnvDict(name="default3", value="val3"), + K8sEnvDict(name="override1", value="override1"), + K8sEnvDict(name="override2", value="override2"), + ] + + actual_values = _add_list_values(default_values, override_values) + actual_values.sort(key=lambda x: x["name"]) + assert expected_values == actual_values diff --git a/server/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py similarity index 73% rename from server/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py index b409463e..041b12aa 100644 --- a/server/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_model_endpoint_infra_gateway.py @@ -2,8 +2,8 @@ from unittest.mock import Mock import pytest -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.infra.gateways import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.infra.gateways import ( LiveModelEndpointInfraGateway, live_model_endpoint_infra_gateway, ) @@ -43,6 +43,7 @@ def test_create_model_endpoint_infra( memory=endpoint.infra_state.resource_state.memory, gpu_type=endpoint.infra_state.resource_state.gpu_type, storage=endpoint.infra_state.resource_state.storage, + nodes_per_worker=endpoint.infra_state.resource_state.nodes_per_worker, optimize_costs=bool(endpoint.infra_state.resource_state.optimize_costs), aws_role=endpoint.infra_state.aws_role, results_s3_bucket=endpoint.infra_state.results_s3_bucket, @@ -91,6 +92,15 @@ async def test_update_model_endpoint_infra( ), ) assert creation_task_id_1 + # Test existing billing tags don't get lost + endpoint_config = model_endpoint_1.infra_state.user_config_state.endpoint_config # type: ignore + billing_tags = endpoint_config.billing_tags # type: ignore + assert ( + fake_task_queue_gateway.get_task_args(creation_task_id_1)["kwargs"][ + "build_endpoint_request_json" + ].get("billing_tags") + == billing_tags + ) creation_task_id_2 = await model_endpoint_infra_gateway.update_model_endpoint_infra( model_endpoint_record=model_endpoint_1.record, @@ -100,8 +110,59 @@ async def test_update_model_endpoint_infra( gpu_type=model_endpoint_2.infra_state.resource_state.gpu_type, child_fn_info=model_endpoint_2.infra_state.child_fn_info, labels=model_endpoint_2.infra_state.labels, + billing_tags={ + "idempotencyKeyPrefix": "new_value_1", + "product": "value2", + "type": "hi", + "subType": "hi", + "tags": {"nested_tag_1": "nested_value_1"}, + "payee": "hi", + "payor": "hi", + "reference": {"referenceType": "hi", "referenceId": "hi"}, + }, ) assert creation_task_id_2 + # Inspect the value of billing_tags across the wire to make sure it's set correctly + # Test new billing tags overwrite existing ones + assert ( + fake_task_queue_gateway.get_task_args(creation_task_id_2)["kwargs"][ + "build_endpoint_request_json" + ] + .get("billing_tags") + .get("idempotencyKeyPrefix") + == "new_value_1" + ) + + +@pytest.mark.asyncio +async def test_update_multinode_endpoint_keeps_nodes_per_worker( + model_endpoint_infra_gateway: LiveModelEndpointInfraGateway, + model_endpoint_1: ModelEndpoint, + fake_task_queue_gateway, +): + model_endpoint_1.infra_state.resource_state.nodes_per_worker = 2 + resource_gateway: Any = model_endpoint_infra_gateway.resource_gateway + existing_infra_state = model_endpoint_1.infra_state + assert existing_infra_state is not None + live_model_endpoint_infra_gateway.generate_deployment_name = Mock( + return_value=existing_infra_state.deployment_name + ) + resource_gateway.add_resource(model_endpoint_1.record.id, existing_infra_state) + + creation_task_id_1 = await model_endpoint_infra_gateway.update_model_endpoint_infra( + model_endpoint_record=model_endpoint_1.record, + max_workers=2, + cpus=2, + memory=2, + storage=2, + ) + assert creation_task_id_1 + assert ( + fake_task_queue_gateway.get_task_args(creation_task_id_1)["kwargs"][ + "build_endpoint_request_json" + ].get("nodes_per_worker") + == 2 + ) @pytest.mark.asyncio diff --git a/server/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py similarity index 96% rename from server/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py rename to model-engine/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py index e60bbcff..9b3f2ad7 100644 --- a/server/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py +++ b/model-engine/tests/unit/infra/gateways/test_live_model_endpoints_schema_gateway.py @@ -1,14 +1,12 @@ import pytest -from llm_engine_server.domain.entities import ModelEndpoint -from llm_engine_server.infra.gateways.live_model_endpoints_schema_gateway import ( +from model_engine_server.domain.entities import ModelEndpoint +from model_engine_server.infra.gateways.live_model_endpoints_schema_gateway import ( LiveModelEndpointsSchemaGateway, ) @pytest.fixture -def live_model_endpoints_schema_gateway( - fake_filesystem_gateway, -) -> LiveModelEndpointsSchemaGateway: +def live_model_endpoints_schema_gateway(fake_filesystem_gateway) -> LiveModelEndpointsSchemaGateway: return LiveModelEndpointsSchemaGateway(filesystem_gateway=fake_filesystem_gateway) diff --git a/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py new file mode 100644 index 00000000..2bddfffa --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py @@ -0,0 +1,282 @@ +import json +from dataclasses import dataclass +from typing import Any, Dict, Tuple +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, + SyncEndpointPredictV1Response, +) +from model_engine_server.domain.exceptions import InvalidRequestException, UpstreamServiceError +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway +from model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway import ( + LiveStreamingModelEndpointInferenceGateway, +) + + +@dataclass +class FakeIterator: + content: bytes = b'{"test": "content"}' + count: int = 0 + + def __aiter__(self): + return self + + async def __anext__(self): + self.count = self.count + 1 + if self.count == 1: + return b"data: " + self.content + if self.count in {2, 3}: + return b"\n" + if self.count == 4: + raise StopAsyncIteration + + +@dataclass +class FakeResponse: + def __init__(self, status: int, message_content: bytes = b'{"test": "content"}'): + self.status = status + self.message_content = message_content + self.content = FakeIterator(content=message_content) + + async def read(self): + return self.message_content + + +def _get_mock_client_session(fake_response: FakeResponse): + mock_post = AsyncMock(return_value=fake_response) + mock_client_session_val = AsyncMock() + mock_client_session_val.post = mock_post + mock_client_session_val.__aenter__ = AsyncMock(return_value=mock_client_session_val) + mock_client_session_val.__aexit__ = AsyncMock() + mock_client_session = MagicMock(return_value=mock_client_session_val) + return mock_client_session + + +def _get_mock_client_session_with_client_connector_error(): + mock_post = AsyncMock( + side_effect=aiohttp.ClientConnectorError(connection_key=None, os_error=OSError()) + ) + mock_client_session_val = AsyncMock() + mock_client_session_val.post = mock_post + mock_client_session_val.__aenter__ = AsyncMock(return_value=mock_client_session_val) + + async def _aexit(*exc): + pass + + mock_client_session_val.__aexit__ = AsyncMock(side_effect=_aexit) + mock_client_session = MagicMock(return_value=mock_client_session_val) + return mock_client_session + + +@pytest.mark.asyncio +async def test_make_request_with_retries_success( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + fake_response = FakeResponse(status=200) + mock_client_session = _get_mock_client_session(fake_response) + + with patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) + count = 0 + async for message in response: + assert message == {"test": "content"} + count += 1 + assert count == 1 + + +@pytest.mark.asyncio +async def test_make_request_with_retries_failed_429( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + fake_response = FakeResponse(status=429) + mock_client_session = _get_mock_client_session(fake_response) + + with pytest.raises(UpstreamServiceError), patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + async for response in gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ): + response + + +@pytest.mark.asyncio +async def test_make_request_with_retries_failed_traceback( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + fake_response = FakeResponse(status=500) + mock_client_session = _get_mock_client_session(fake_response) + + with pytest.raises(UpstreamServiceError), patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + async for response in gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ): + response + + +@pytest.mark.asyncio +async def test_make_request_with_retries_failed_with_client_connector_error( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + mock_client_session = _get_mock_client_session_with_client_connector_error() + + with pytest.raises(UpstreamServiceError), patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + async for response in gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ): + response + + +@pytest.mark.asyncio +async def test_streaming_predict_success( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + fake_response = FakeResponse(status=200) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = gateway.streaming_predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + count = 0 + async for message in response: + assert isinstance(message, SyncEndpointPredictV1Response) + assert message.dict() == { + "status": "SUCCESS", + "result": {"test": "content"}, + "traceback": None, + } + count += 1 + assert count == 1 + + +@pytest.mark.asyncio +async def test_predict_raises_traceback_json( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + content = json.dumps({"detail": {"traceback": "test_traceback"}}).encode("utf-8") + fake_response = FakeResponse(status=500, message_content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = gateway.streaming_predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + count = 0 + async for message in response: + assert isinstance(message, SyncEndpointPredictV1Response) + assert message.dict() == { + "status": "FAILURE", + "result": None, + "traceback": "test_traceback", + } + count += 1 + assert count == 1 + + +@pytest.mark.asyncio +async def test_predict_raises_traceback_not_json( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + content = b"Test traceback content" + fake_response = FakeResponse(status=500, message_content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = gateway.streaming_predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + count = 0 + async for message in response: + assert isinstance(message, SyncEndpointPredictV1Response) + assert message.dict() == { + "status": "FAILURE", + "result": None, + "traceback": "Test traceback content", + } + count += 1 + assert count == 1 + + +@pytest.mark.asyncio +async def test_predict_upstream_raises_400( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveStreamingModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + content = json.dumps({"result": json.dumps({"error": "error"})}).encode("utf-8") + + fake_response = FakeResponse(status=400, message_content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + with pytest.raises(InvalidRequestException): + response = gateway.streaming_predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + async for message in response: + message diff --git a/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py new file mode 100644 index 00000000..608b73cf --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py @@ -0,0 +1,306 @@ +import json +from dataclasses import dataclass +from typing import Any, Dict, Tuple +from unittest.mock import AsyncMock, MagicMock, patch + +import aiohttp +import pytest +from model_engine_server.common.dtos.tasks import ( + SyncEndpointPredictV1Request, + SyncEndpointPredictV1Response, +) +from model_engine_server.domain.exceptions import InvalidRequestException, UpstreamServiceError +from model_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway +from model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway import ( + LiveSyncModelEndpointInferenceGateway, +) + + +@dataclass +class FakeResponse: + status: int + content: bytes = b"test_content" + body: Any = None + + async def read(self): + return self.content + + async def json(self): + return self.body if self.body else {"test_key": "test_value"} + + +def _get_mock_client_session(fake_response: FakeResponse): + mock_post = AsyncMock(return_value=fake_response) + mock_client_session_val = AsyncMock() + mock_client_session_val.post = mock_post + mock_client_session_val.__aenter__ = AsyncMock(return_value=mock_client_session_val) + mock_client_session_val.__aexit__ = AsyncMock() + mock_client_session = MagicMock(return_value=mock_client_session_val) + return mock_client_session + + +def _get_mock_client_session_with_client_connector_error(): + mock_post = AsyncMock( + side_effect=aiohttp.ClientConnectorError(connection_key=None, os_error=OSError()) + ) + mock_client_session_val = AsyncMock() + mock_client_session_val.post = mock_post + mock_client_session_val.__aenter__ = AsyncMock(return_value=mock_client_session_val) + + async def _aexit(*exc): + pass + + mock_client_session_val.__aexit__ = AsyncMock(side_effect=_aexit) + mock_client_session = MagicMock(return_value=mock_client_session_val) + return mock_client_session + + +@pytest.mark.asyncio +async def test_make_request_with_retries_success( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + fake_response = FakeResponse(status=200) + mock_client_session = _get_mock_client_session(fake_response) + + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = await gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) + assert response == {"test_key": "test_value"} + + +@pytest.mark.asyncio +async def test_make_request_with_retries_failed_429( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + fake_response = FakeResponse(status=429) + mock_client_session = _get_mock_client_session(fake_response) + + with pytest.raises(UpstreamServiceError), patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + await gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) + + +@pytest.mark.asyncio +async def test_make_request_with_retries_failed_traceback( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + fake_response = FakeResponse(status=500) + mock_client_session = _get_mock_client_session(fake_response) + + with pytest.raises(UpstreamServiceError), patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + await gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) + + +@pytest.mark.asyncio +async def test_make_request_with_retries_failed_with_client_connector_error( + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + mock_client_session = _get_mock_client_session_with_client_connector_error() + + with pytest.raises(UpstreamServiceError), patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + await gateway.make_request_with_retries( + "test_request_url", {}, 0.05, 2, "test_endpoint_name" + ) + + +@pytest.mark.asyncio +async def test_predict_success( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + fake_response = FakeResponse(status=200, body={"test_key": "test_value"}) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = await gateway.predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + assert isinstance(response, SyncEndpointPredictV1Response) + assert response.dict() == { + "status": "SUCCESS", + "result": {"test_key": "test_value"}, + "traceback": None, + } + + +@pytest.mark.asyncio +async def test_predict_raises_traceback_json( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + content = json.dumps({"detail": {"traceback": "test_traceback"}}).encode("utf-8") + fake_response = FakeResponse(status=500, content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = await gateway.predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + assert isinstance(response, SyncEndpointPredictV1Response) + assert response.dict() == { + "status": "FAILURE", + "result": None, + "traceback": "test_traceback", + } + + +@pytest.mark.asyncio +async def test_predict_raises_traceback_not_json( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + content = b"Test traceback content" + fake_response = FakeResponse(status=500, content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = await gateway.predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + assert isinstance(response, SyncEndpointPredictV1Response) + assert response.dict() == { + "status": "FAILURE", + "result": None, + "traceback": "Test traceback content", + } + + +@pytest.mark.asyncio +async def test_predict_raises_traceback_wrapped( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + content = json.dumps( + {"result": json.dumps({"detail": {"traceback": "test_traceback"}})} + ).encode("utf-8") + fake_response = FakeResponse(status=500, content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = await gateway.predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + assert isinstance(response, SyncEndpointPredictV1Response) + assert response.dict() == { + "status": "FAILURE", + "result": None, + "traceback": "test_traceback", + } + + +@pytest.mark.asyncio +async def test_predict_raises_traceback_wrapped_detail_array( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + content = json.dumps({"result": json.dumps({"detail": [{"error": "error"}]})}).encode("utf-8") + fake_response = FakeResponse(status=500, content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + response = await gateway.predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) + assert isinstance(response, SyncEndpointPredictV1Response) + assert response.dict() == { + "status": "FAILURE", + "result": None, + "traceback": """{"detail":[{"error":"error"}]}""", + } + + +@pytest.mark.asyncio +async def test_predict_upstream_raises_400( + sync_endpoint_predict_request_1: Tuple[SyncEndpointPredictV1Request, Dict[str, Any]], + fake_monitoring_metrics_gateway: MonitoringMetricsGateway, +): + gateway = LiveSyncModelEndpointInferenceGateway( + monitoring_metrics_gateway=fake_monitoring_metrics_gateway, use_asyncio=True + ) + + content = json.dumps({"result": json.dumps({"error": "error"})}).encode("utf-8") + fake_response = FakeResponse(status=400, content=content) + mock_client_session = _get_mock_client_session(fake_response) + with patch( + "model_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", + mock_client_session, + ): + # assert that the exception is raised + with pytest.raises(InvalidRequestException): + await gateway.predict( + topic="test_topic", + predict_request=sync_endpoint_predict_request_1[0], + endpoint_name="test_name", + ) diff --git a/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py b/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py new file mode 100644 index 00000000..9e989959 --- /dev/null +++ b/model-engine/tests/unit/infra/gateways/test_s3_llm_artifact_gateway.py @@ -0,0 +1,87 @@ +from typing import List +from unittest import mock + +import pytest +from model_engine_server.common.config import hmi_config +from model_engine_server.infra.gateways.s3_llm_artifact_gateway import S3LLMArtifactGateway + + +@pytest.fixture +def llm_artifact_gateway(): + gateway = S3LLMArtifactGateway() + return gateway + + +@pytest.fixture +def fake_files(): + return ["fake-prefix/fake1", "fake-prefix/fake2", "fake-prefix/fake3", "fake-prefix-ext/fake1"] + + +def mock_boto3_session(fake_files: List[str]): + mock_session = mock.Mock() + mock_bucket = mock.Mock() + mock_objects = mock.Mock() + + def filter_files(*args, **kwargs): + prefix = kwargs["Prefix"] + return [mock.Mock(key=file) for file in fake_files if file.startswith(prefix)] + + mock_session.return_value.resource.return_value.Bucket.return_value = mock_bucket + mock_bucket.objects = mock_objects + mock_objects.filter.side_effect = filter_files + + mock_bucket.download_file.return_value = None + return mock_session + + +@mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.os.makedirs", + lambda *args, **kwargs: None, # noqa +) +def test_s3_llm_artifact_gateway_download_folder(llm_artifact_gateway, fake_files): + prefix = "/".join(fake_files[0].split("/")[:-1]) + "/" + uri_prefix = f"s3://fake-bucket/{prefix}" + target_dir = "fake-target" + + expected_files = [ + f"{target_dir}/{file.split('/')[-1]}" for file in fake_files if file.startswith(prefix) + ] + with mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session", + mock_boto3_session(fake_files), + ): + assert llm_artifact_gateway.download_files(uri_prefix, target_dir) == expected_files + + +@mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.os.makedirs", + lambda *args, **kwargs: None, # noqa +) +def test_s3_llm_artifact_gateway_download_file(llm_artifact_gateway, fake_files): + file = fake_files[1] + uri = f"s3://fake-bucket/{file}" + target = f"fake-target/{file}" + + with mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session", + mock_boto3_session(fake_files), + ): + assert llm_artifact_gateway.download_files(uri, target) == [target] + + +def test_s3_llm_artifact_gateway_get_model_weights(llm_artifact_gateway): + owner = "fakeuser" + model_name = "fakemodel" + fake_files = [f"{owner}/models--{model_name}/fake1", f"{owner}/models--{model_name}/fake2"] + + s3_prefix = hmi_config.hf_user_fine_tuned_weights_prefix + weights_prefix = "/".join(s3_prefix.replace("s3://", "").split("/")[1:]) + fake_model_weights = [f"{weights_prefix}/{file}" for file in fake_files] + expected_model_files = [f"{s3_prefix}/{file}" for file in fake_files] + with mock.patch( + "model_engine_server.infra.gateways.s3_llm_artifact_gateway.boto3.Session", + mock_boto3_session(fake_model_weights), + ): + assert ( + llm_artifact_gateway.get_model_weights_urls(owner, model_name) == expected_model_files + ) diff --git a/server/tests/unit/infra/repositories/conftest.py b/model-engine/tests/unit/infra/repositories/conftest.py similarity index 97% rename from server/tests/unit/infra/repositories/conftest.py rename to model-engine/tests/unit/infra/repositories/conftest.py index dcd0260a..dbf0109e 100644 --- a/server/tests/unit/infra/repositories/conftest.py +++ b/model-engine/tests/unit/infra/repositories/conftest.py @@ -2,10 +2,10 @@ from typing import Callable, Optional, Union import pytest -from llm_engine_server.db.models import BatchJob, Bundle -from llm_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle -from llm_engine_server.db.models import Endpoint -from llm_engine_server.domain.entities import ( +from model_engine_server.db.models import BatchJob, Bundle +from model_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle +from model_engine_server.db.models import Endpoint +from model_engine_server.domain.entities import ( BatchJobRecord, GpuType, ModelBundle, @@ -170,7 +170,7 @@ def orm_model_bundle_4(test_api_key: str) -> Bundle: "ecr_repo": "test_repo", "image_tag": "test_tag", }, - packaging_type="cloudpickle", + packaging_type="lira", app_config=None, ) model_bundle.id = "test_model_bundle_id_4" @@ -205,7 +205,7 @@ def orm_model_bundle_5(test_api_key: str) -> Bundle: "ecr_repo": "test_repo", "image_tag": "test_tag", }, - packaging_type="cloudpickle", + packaging_type="lira", app_config=None, ) model_bundle.id = "test_model_bundle_id_5" @@ -276,6 +276,7 @@ def entity_model_endpoint_infra_state() -> ModelEndpointInfraState: memory="1G", gpu_type=GpuType.NVIDIA_TESLA_T4, storage="10G", + nodes_per_worker=1, optimize_costs=False, ), user_config_state=ModelEndpointUserConfigState( diff --git a/server/tests/unit/infra/repositories/test_db_batch_job_record_repository.py b/model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py similarity index 97% rename from server/tests/unit/infra/repositories/test_db_batch_job_record_repository.py rename to model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py index 95859158..214a1ebc 100644 --- a/server/tests/unit/infra/repositories/test_db_batch_job_record_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_batch_job_record_repository.py @@ -3,10 +3,10 @@ from unittest.mock import AsyncMock import pytest -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException -from llm_engine_server.db.models import BatchJob, Bundle -from llm_engine_server.domain.entities import BatchJobRecord -from llm_engine_server.infra.repositories.db_batch_job_record_repository import ( +from model_engine_server.db.models import BatchJob, Bundle +from model_engine_server.domain.entities import BatchJobRecord +from model_engine_server.domain.exceptions import ReadOnlyDatabaseException +from model_engine_server.infra.repositories.db_batch_job_record_repository import ( DbBatchJobRecordRepository, OrmBatchJob, ) diff --git a/server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py b/model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py similarity index 94% rename from server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py rename to model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py index 81d172ad..2bfaab3b 100644 --- a/server/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_docker_image_batch_job_bundle_repository.py @@ -3,16 +3,18 @@ from unittest.mock import AsyncMock import pytest -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException -from llm_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle -from llm_engine_server.domain.entities import GpuType -from llm_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.db.models import DockerImageBatchJobBundle as OrmDockerImageBatchJobBundle +from model_engine_server.domain.entities import GpuType +from model_engine_server.domain.entities.docker_image_batch_job_bundle_entity import ( DockerImageBatchJobBundle, ) -from llm_engine_server.domain.exceptions import CorruptRecordInfraStateException -from llm_engine_server.infra.repositories import DbDockerImageBatchJobBundleRepository -from llm_engine_server.infra.repositories.db_docker_image_batch_job_bundle_repository import ( +from model_engine_server.domain.exceptions import ( + CorruptRecordInfraStateException, + ReadOnlyDatabaseException, +) +from model_engine_server.infra.repositories import DbDockerImageBatchJobBundleRepository +from model_engine_server.infra.repositories.db_docker_image_batch_job_bundle_repository import ( translate_docker_image_batch_job_bundle_orm_to_entity, ) from sqlalchemy.ext.asyncio import AsyncSession @@ -106,7 +108,6 @@ async def test_list_docker_image_batch_job_bundles( test_api_key: str, test_api_key_team: str, ): - orm_docker_image_batch_job_bundle_1_v2.created_by = test_api_key_team orm_docker_image_batch_job_bundle_1_v2.owner = test_api_key_team docker_image_batch_job_bundle_1_v2.created_by = test_api_key_team diff --git a/server/tests/unit/infra/repositories/test_db_model_bundle_repository.py b/model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py similarity index 97% rename from server/tests/unit/infra/repositories/test_db_model_bundle_repository.py rename to model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py index 0e45d4a4..4eb94d20 100644 --- a/server/tests/unit/infra/repositories/test_db_model_bundle_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_model_bundle_repository.py @@ -3,16 +3,16 @@ from unittest.mock import AsyncMock import pytest -from llm_engine_server.common.dtos.model_bundles import ModelBundleOrderBy -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException -from llm_engine_server.db.models import Bundle -from llm_engine_server.domain.entities import ( +from model_engine_server.common.dtos.model_bundles import ModelBundleOrderBy +from model_engine_server.db.models import Bundle +from model_engine_server.domain.entities import ( CloudpickleArtifactFlavor, ModelBundle, ModelBundlePackagingType, PytorchFramework, ) -from llm_engine_server.infra.repositories.db_model_bundle_repository import ( +from model_engine_server.domain.exceptions import ReadOnlyDatabaseException +from model_engine_server.infra.repositories.db_model_bundle_repository import ( DbModelBundleRepository, OrmModelBundle, ) diff --git a/server/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py b/model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py similarity index 95% rename from server/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py rename to model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py index 38103bc2..3ad72127 100644 --- a/server/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_db_model_endpoint_record_repository.py @@ -3,13 +3,13 @@ from unittest.mock import AsyncMock, Mock import pytest -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.core.domain_exceptions import ReadOnlyDatabaseException -from llm_engine_server.db.models import Bundle, Endpoint -from llm_engine_server.domain.entities import ModelEndpointRecord -from llm_engine_server.infra.gateways import FakeMonitoringMetricsGateway -from llm_engine_server.infra.repositories import db_model_endpoint_record_repository -from llm_engine_server.infra.repositories.db_model_endpoint_record_repository import ( +from model_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy +from model_engine_server.db.models import Bundle, Endpoint +from model_engine_server.domain.entities import ModelEndpointRecord +from model_engine_server.domain.exceptions import ReadOnlyDatabaseException +from model_engine_server.infra.gateways import FakeMonitoringMetricsGateway +from model_engine_server.infra.repositories import db_model_endpoint_record_repository +from model_engine_server.infra.repositories.db_model_endpoint_record_repository import ( DbModelEndpointRecordRepository, OrmModelEndpoint, ) @@ -140,7 +140,7 @@ async def test_list_llm_model_endpoint_records( orm_model_bundle: Bundle, fake_monitoring_metrics_gateway: FakeMonitoringMetricsGateway, ): - filter_content = "endpoint_metadata ? '_llm' AND llm_engine.endpoints.name = :name_1 AND (llm_engine.endpoints.owner = :owner_1 OR llm_engine.endpoints.public_inference = true)" + filter_content = "endpoint_metadata ? '_llm' AND hosted_model_inference.endpoints.name = :name_1 AND (hosted_model_inference.endpoints.owner = :owner_1 OR hosted_model_inference.endpoints.public_inference = true)" def mock_llm_model_endpoint_select_all_by_filters( session: AsyncSession, filters: Any @@ -169,7 +169,7 @@ def mock_llm_model_endpoint_select_all_by_filters( order_by=ModelEndpointOrderBy.NEWEST, ) - filter_content = "endpoint_metadata ? '_llm' AND (llm_engine.endpoints.owner = :owner_1 OR llm_engine.endpoints.public_inference = true)" + filter_content = "endpoint_metadata ? '_llm' AND (hosted_model_inference.endpoints.owner = :owner_1 OR hosted_model_inference.endpoints.public_inference = true)" await repo.list_llm_model_endpoint_records( owner="test_user_id", name=None, diff --git a/model-engine/tests/unit/infra/repositories/test_live_tokenizer_repository.py b/model-engine/tests/unit/infra/repositories/test_live_tokenizer_repository.py new file mode 100644 index 00000000..b82d78f4 --- /dev/null +++ b/model-engine/tests/unit/infra/repositories/test_live_tokenizer_repository.py @@ -0,0 +1,62 @@ +from typing import Any, List +from unittest import mock + +import pytest +from model_engine_server.infra.repositories.live_tokenizer_repository import ( + LiveTokenizerRepository, + ModelInfo, +) + + +@pytest.fixture +def tokenizer_repository(fake_llm_artifact_gateway): + repository = LiveTokenizerRepository(fake_llm_artifact_gateway) + return repository + + +def mocked_get_models_s3_uri(*args, **kwargs): # noqa + return f"s3://fake-bucket/{args[0]}/{args[1]}" + + +def mocked_auto_tokenizer_from_pretrained(*args, **kwargs): # noqa + class mocked_encode: + def encode(self, input: str) -> List[Any]: + return [1] * len(input) + + return mocked_encode() + + +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.SUPPORTED_MODELS_INFO", + {"llama-7b": ModelInfo("llama-7b", None)}, +) +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.list_repo_refs", + lambda *args, **kwargs: None, # noqa +) +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.AutoTokenizer.from_pretrained", + mocked_auto_tokenizer_from_pretrained, +) +def test_load_tokenizer_from_hf(tokenizer_repository): + tokenizer = tokenizer_repository.load_tokenizer("llama-7b") + + assert tokenizer.encode("fake input") == [1] * len("fake input") + + +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.SUPPORTED_MODELS_INFO", + {"llama-7b": ModelInfo(None, "llama-7b")}, +) +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.get_models_s3_uri", + mocked_get_models_s3_uri, +) +@mock.patch( + "model_engine_server.infra.repositories.live_tokenizer_repository.AutoTokenizer.from_pretrained", + mocked_auto_tokenizer_from_pretrained, +) +def test_load_tokenizer_from_s3(tokenizer_repository): + tokenizer = tokenizer_repository.load_tokenizer("llama-7b") + + assert tokenizer.encode("fake input") == [1] * len("fake input") diff --git a/server/tests/unit/infra/repositories/test_redis_feature_flag_repository.py b/model-engine/tests/unit/infra/repositories/test_redis_feature_flag_repository.py similarity index 88% rename from server/tests/unit/infra/repositories/test_redis_feature_flag_repository.py rename to model-engine/tests/unit/infra/repositories/test_redis_feature_flag_repository.py index 50871f6e..5bf3a0e5 100644 --- a/server/tests/unit/infra/repositories/test_redis_feature_flag_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_redis_feature_flag_repository.py @@ -2,7 +2,7 @@ import aioredis import pytest -from llm_engine_server.infra.repositories.redis_feature_flag_repository import ( +from model_engine_server.infra.repositories.redis_feature_flag_repository import ( RedisFeatureFlagRepository, ) diff --git a/server/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py b/model-engine/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py similarity index 92% rename from server/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py rename to model-engine/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py index f7cdb743..eb1133fb 100644 --- a/server/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py +++ b/model-engine/tests/unit/infra/repositories/test_redis_model_endpoint_cache_repository.py @@ -2,7 +2,7 @@ import aioredis import pytest -from llm_engine_server.infra.repositories.redis_model_endpoint_cache_repository import ( +from model_engine_server.infra.repositories.redis_model_endpoint_cache_repository import ( RedisModelEndpointCacheRepository, ) diff --git a/server/tests/unit/infra/services/conftest.py b/model-engine/tests/unit/infra/services/conftest.py similarity index 84% rename from server/tests/unit/infra/services/conftest.py rename to model-engine/tests/unit/infra/services/conftest.py index a8c2edb6..873eea9e 100644 --- a/server/tests/unit/infra/services/conftest.py +++ b/model-engine/tests/unit/infra/services/conftest.py @@ -1,10 +1,10 @@ import pytest -from llm_engine_server.domain.entities import ModelBundle, ModelEndpoint -from llm_engine_server.infra.gateways import ( +from model_engine_server.domain.entities import ModelBundle, ModelEndpoint +from model_engine_server.infra.gateways import ( LiveBatchJobProgressGateway, LiveModelEndpointsSchemaGateway, ) -from llm_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService +from model_engine_server.infra.services import LiveBatchJobService, LiveModelEndpointService @pytest.fixture @@ -16,6 +16,7 @@ def fake_live_model_endpoint_service( fake_async_model_endpoint_inference_gateway, fake_streaming_model_endpoint_inference_gateway, fake_sync_model_endpoint_inference_gateway, + fake_inference_autoscaling_metrics_gateway, fake_filesystem_gateway, model_bundle_1: ModelBundle, model_bundle_2: ModelBundle, @@ -37,7 +38,9 @@ def fake_live_model_endpoint_service( async_model_endpoint_inference_gateway=fake_async_model_endpoint_inference_gateway, streaming_model_endpoint_inference_gateway=fake_streaming_model_endpoint_inference_gateway, sync_model_endpoint_inference_gateway=fake_sync_model_endpoint_inference_gateway, + inference_autoscaling_metrics_gateway=fake_inference_autoscaling_metrics_gateway, model_endpoints_schema_gateway=model_endpoints_schema_gateway, + can_scale_http_endpoint_from_zero_flag=True, # reasonable default, gets overridden in individual tests if needed ) return service diff --git a/model-engine/tests/unit/infra/services/test_docker_image_batch_job_llm_fine_tuning_service.py b/model-engine/tests/unit/infra/services/test_docker_image_batch_job_llm_fine_tuning_service.py new file mode 100644 index 00000000..598b5c1b --- /dev/null +++ b/model-engine/tests/unit/infra/services/test_docker_image_batch_job_llm_fine_tuning_service.py @@ -0,0 +1,65 @@ +import pytest +import pytest_asyncio +from model_engine_server.domain.entities.llm_fine_tune_entity import LLMFineTuneTemplate +from model_engine_server.infra.services import DockerImageBatchJobLLMFineTuningService + + +@pytest_asyncio.fixture +async def live_docker_image_batch_job_llm_fine_tuning_service( + fake_docker_image_batch_job_gateway, + fake_docker_image_batch_job_bundle_repository, + fake_llm_fine_tune_repository, +): + fake_bundle = ( + await fake_docker_image_batch_job_bundle_repository.create_docker_image_batch_job_bundle( + name="fake_fine_tune_bundle", + created_by="fake_egp_admin", + owner="fake_egp_admin", + image_repository="fake_image_repo", + image_tag="fake_image_tag", + command=["fake_command"], + env={"fake_env": "fake_env"}, + mount_location="/fake_mount_location", + cpus="1", + memory="0.1Gi", + storage="1Gi", + gpus=0, + gpu_type=None, + public=True, + ) + ) + await fake_llm_fine_tune_repository.write_job_template_for_model( + model_name="fake_model_name", + fine_tuning_method="fake_fine_tuning_method", + job_template=LLMFineTuneTemplate( + docker_image_batch_job_bundle_id=fake_bundle.id, + launch_endpoint_config={}, + default_hparams={}, + required_params=[], + ), + ) + return DockerImageBatchJobLLMFineTuningService( + docker_image_batch_job_gateway=fake_docker_image_batch_job_gateway, + docker_image_batch_job_bundle_repo=fake_docker_image_batch_job_bundle_repository, + llm_fine_tune_repository=fake_llm_fine_tune_repository, + ) + + +@pytest.mark.asyncio +async def test_create_fine_tune_success( + live_docker_image_batch_job_llm_fine_tuning_service, + fake_docker_image_batch_job_gateway, +): + batch_job_id = await live_docker_image_batch_job_llm_fine_tuning_service.create_fine_tune( + created_by="fake_user", + owner="fake_user", + model="fake_model_name", + training_file="fake_training_file_path", + validation_file="fake_validation_file_path", + fine_tuning_method="fake_fine_tuning_method", + hyperparameters={}, + fine_tuned_model="fake_fine_tuned_model_name", + wandb_config=None, + ) + assert batch_job_id is not None + assert fake_docker_image_batch_job_gateway.get_docker_image_batch_job(batch_job_id) is not None diff --git a/model-engine/tests/unit/infra/services/test_image_cache_service.py b/model-engine/tests/unit/infra/services/test_image_cache_service.py new file mode 100644 index 00000000..5f3bb72d --- /dev/null +++ b/model-engine/tests/unit/infra/services/test_image_cache_service.py @@ -0,0 +1,80 @@ +from typing import Any + +import pytest +from model_engine_server.common.config import hmi_config +from model_engine_server.common.env_vars import GIT_TAG +from model_engine_server.core.config import infra_config +from model_engine_server.infra.services.image_cache_service import DockerImage, ImageCacheService + + +@pytest.mark.asyncio +async def test_image_cache_success( + fake_image_cache_service: ImageCacheService, + model_endpoint_1, + model_endpoint_2, + model_endpoint_3, + model_endpoint_4, +): + infra_states = { + model_endpoint_1.record.id: (bool, model_endpoint_1.infra_state), + model_endpoint_2.record.id: (bool, model_endpoint_2.infra_state), + model_endpoint_3.record.id: (bool, model_endpoint_3.infra_state), + model_endpoint_4.record.id: (bool, model_endpoint_4.infra_state), + } + repo: Any = fake_image_cache_service.model_endpoint_record_repository + repo.add_model_endpoint_record(model_endpoint_1.record) + repo.add_model_endpoint_record(model_endpoint_2.record) + repo.add_model_endpoint_record(model_endpoint_3.record) + repo.add_model_endpoint_record(model_endpoint_4.record) + + await fake_image_cache_service.execute(infra_states) # type: ignore + gateway: Any = fake_image_cache_service.image_cache_gateway + + assert ( + f"{infra_config().ml_account_id}.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg222" + in gateway.cached_images["t4"] + ) + assert ( + f"{infra_config().ml_account_id}.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg111111111" + in gateway.cached_images["t4"] + ) + assert ( + f"{infra_config().ml_account_id}.dkr.ecr.us-west-2.amazonaws.com/my-repo:abcdefg00000" + in gateway.cached_images["t4"] + ) + + +@pytest.mark.asyncio +async def test_caching_finetune_llm_images( + fake_image_cache_service: ImageCacheService, +): + await fake_image_cache_service.execute({}) + gateway: Any = fake_image_cache_service.image_cache_gateway + + istio_image = DockerImage("gcr.io/istio-release/proxyv2", "1.15.0") + tgi_image_110 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.tgi_repository}", "1.1.0" + ) + vllm_image_027 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.2.7" + ) + vllm_image_032 = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.vllm_repository}", "0.3.2" + ) + latest_tag = "fake_docker_repository_latest_image_tag" + vllm_batch_image_latest = DockerImage( + f"{infra_config().docker_repo_prefix}/{hmi_config.batch_inference_vllm_repository}", + latest_tag, + ) + forwarder_image = DockerImage(f"{infra_config().docker_repo_prefix}/model-engine", GIT_TAG) + + for key in ["a10", "a100", "h100", "h100_3g40gb", "h100_1g20gb"]: + for llm_image in [ + istio_image, + tgi_image_110, + vllm_image_027, + vllm_image_032, + vllm_batch_image_latest, + forwarder_image, + ]: + assert f"{llm_image.repo}:{llm_image.tag}" in gateway.cached_images[key] diff --git a/server/tests/unit/infra/services/test_live_batch_job_orchestration_service.py b/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py similarity index 93% rename from server/tests/unit/infra/services/test_live_batch_job_orchestration_service.py rename to model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py index 9b3ff377..7f80b4a1 100644 --- a/server/tests/unit/infra/services/test_live_batch_job_orchestration_service.py +++ b/model-engine/tests/unit/infra/services/test_live_batch_job_orchestration_service.py @@ -4,10 +4,9 @@ from unittest.mock import patch import pytest -from llm_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME -from llm_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, ResponseSchema, TaskStatus -from llm_engine_server.core.domain_exceptions import ObjectNotFoundException -from llm_engine_server.domain.entities import ( +from model_engine_server.common.constants import DEFAULT_CELERY_TASK_NAME, LIRA_CELERY_TASK_NAME +from model_engine_server.common.dtos.tasks import GetAsyncTaskV1Response, ResponseSchema, TaskStatus +from model_engine_server.domain.entities import ( BatchJob, BatchJobSerializationFormat, BatchJobStatus, @@ -15,12 +14,13 @@ ModelEndpoint, ModelEndpointStatus, ) -from llm_engine_server.infra.gateways import LiveBatchJobProgressGateway -from llm_engine_server.infra.services import ( +from model_engine_server.domain.exceptions import ObjectNotFoundException +from model_engine_server.infra.gateways import LiveBatchJobProgressGateway +from model_engine_server.infra.services import ( LiveBatchJobOrchestrationService, LiveModelEndpointService, ) -from llm_engine_server.infra.services.live_batch_job_orchestration_service import ( +from model_engine_server.infra.services.live_batch_job_orchestration_service import ( BatchEndpointInferencePredictionResponse, BatchEndpointInProgressTask, ) @@ -50,9 +50,9 @@ def live_batch_job_orchestration_service( assert model_endpoint_1.infra_state is not None assert model_endpoint_runnable.infra_state is not None gateway.db[model_endpoint_1.infra_state.deployment_name] = model_endpoint_1.infra_state - gateway.db[ - model_endpoint_runnable.infra_state.deployment_name - ] = model_endpoint_runnable.infra_state + gateway.db[model_endpoint_runnable.infra_state.deployment_name] = ( + model_endpoint_runnable.infra_state + ) return LiveBatchJobOrchestrationService( model_endpoint_service=fake_live_model_endpoint_service, batch_job_record_repository=fake_batch_job_record_repository, @@ -235,7 +235,7 @@ async def test_run_batch_job_wait_for_endpoint( ): model_endpoint_1.record.status = ModelEndpointStatus.UPDATE_PENDING with patch( - "llm_engine_server.infra.services.live_batch_job_orchestration_service.asyncio.sleep" + "model_engine_server.infra.services.live_batch_job_orchestration_service.asyncio.sleep" ) as mock_sleep: def set_record_ready(*args, **kwargs): diff --git a/server/tests/unit/infra/services/test_live_batch_job_service.py b/model-engine/tests/unit/infra/services/test_live_batch_job_service.py similarity index 94% rename from server/tests/unit/infra/services/test_live_batch_job_service.py rename to model-engine/tests/unit/infra/services/test_live_batch_job_service.py index b1b60a91..8d440a2a 100644 --- a/server/tests/unit/infra/services/test_live_batch_job_service.py +++ b/model-engine/tests/unit/infra/services/test_live_batch_job_service.py @@ -1,9 +1,9 @@ import pytest -from llm_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests -from llm_engine_server.domain.entities import BatchJobSerializationFormat, GpuType, ModelBundle -from llm_engine_server.domain.exceptions import EndpointResourceInvalidRequestException -from llm_engine_server.infra.services import LiveBatchJobService -from llm_engine_server.infra.services.live_batch_job_service import ( +from model_engine_server.common.dtos.batch_jobs import CreateBatchJobResourceRequests +from model_engine_server.domain.entities import BatchJobSerializationFormat, GpuType, ModelBundle +from model_engine_server.domain.exceptions import EndpointResourceInvalidRequestException +from model_engine_server.infra.services import LiveBatchJobService +from model_engine_server.infra.services.live_batch_job_service import ( DEFAULT_ENDPOINT_CPUS_BATCH_JOB, DEFAULT_ENDPOINT_GPU_TYPE_BATCH_JOB, DEFAULT_ENDPOINT_GPUS_BATCH_JOB, diff --git a/server/tests/unit/infra/services/test_live_endpoint_builder_service.py b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py similarity index 85% rename from server/tests/unit/infra/services/test_live_endpoint_builder_service.py rename to model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py index d87be073..a0e876eb 100644 --- a/server/tests/unit/infra/services/test_live_endpoint_builder_service.py +++ b/model-engine/tests/unit/infra/services/test_live_endpoint_builder_service.py @@ -2,22 +2,27 @@ from unittest.mock import Mock, mock_open import pytest -from llm_engine_server.common.dtos.docker_repository import BuildImageResponse -from llm_engine_server.common.dtos.endpoint_builder import ( +from model_engine_server.common.dtos.docker_repository import BuildImageResponse +from model_engine_server.common.dtos.endpoint_builder import ( BuildEndpointRequest, BuildEndpointResponse, BuildEndpointStatus, ) -from llm_engine_server.core.domain_exceptions import DockerBuildFailedException -from llm_engine_server.core.fake_notification_gateway import FakeNotificationGateway -from llm_engine_server.core.notification_gateway import NotificationApp -from llm_engine_server.domain.entities.model_bundle_entity import RunnableImageFlavor -from llm_engine_server.domain.exceptions import EndpointResourceInfraException -from llm_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( +from model_engine_server.core.fake_notification_gateway import FakeNotificationGateway +from model_engine_server.core.notification_gateway import NotificationApp +from model_engine_server.domain.entities.model_bundle_entity import ( + ArtifactLike, + RunnableImageFlavor, +) +from model_engine_server.domain.exceptions import ( + DockerBuildFailedException, + EndpointResourceInfraException, +) +from model_engine_server.infra.gateways.fake_monitoring_metrics_gateway import ( FakeMonitoringMetricsGateway, ) -from llm_engine_server.infra.repositories import ModelEndpointCacheRepository -from llm_engine_server.infra.services import ( +from model_engine_server.infra.repositories import ModelEndpointCacheRepository +from model_engine_server.infra.services import ( LiveEndpointBuilderService, live_endpoint_builder_service, ) @@ -97,8 +102,11 @@ def set_env_vars(): live_endpoint_builder_service.ECR_AWS_PROFILE = "default" live_endpoint_builder_service.GIT_TAG = "test_tag" live_endpoint_builder_service.ENV = "test_env" + live_endpoint_builder_service.WORKSPACE_PATH = ".." live_endpoint_builder_service.open = mock_open() live_endpoint_builder_service.os.mkdir = Mock() + live_endpoint_builder_service.open_wrapper = mock_open() + live_endpoint_builder_service.tempfile.mkstemp = Mock(return_value=["", ""]) @pytest.mark.asyncio @@ -109,6 +117,7 @@ async def test_build_endpoint( build_endpoint_request_async_runnable_image: BuildEndpointRequest, build_endpoint_request_sync_runnable_image: BuildEndpointRequest, build_endpoint_request_streaming_runnable_image: BuildEndpointRequest, + build_endpoint_request_async_zipartifact_highpri: BuildEndpointRequest, endpoint_builder_service_empty_docker_built: LiveEndpointBuilderService, endpoint_builder_service_empty_docker_not_built: LiveEndpointBuilderService, fake_model_endpoint_cache_repository: ModelEndpointCacheRepository, @@ -126,10 +135,12 @@ async def test_build_endpoint( build_endpoint_request_async_runnable_image, build_endpoint_request_sync_runnable_image, build_endpoint_request_streaming_runnable_image, + build_endpoint_request_async_zipartifact_highpri, ]: fake_monitoring_metrics_gateway.reset() repo.add_model_endpoint_record(request.model_endpoint_record) - response = await service.build_endpoint(request) + # Pass in a deep copy of request since LiveEndpointBuilderService.convert_artifact_like_bundle_to_runnable_image mutate the request + response = await service.build_endpoint(request.copy(deep=True)) assert response == BuildEndpointResponse(status=BuildEndpointStatus.OK) assert fake_model_endpoint_cache_repository.read_endpoint_info( endpoint_id=request.model_endpoint_record.id, @@ -138,6 +149,14 @@ async def test_build_endpoint( assert fake_monitoring_metrics_gateway.attempted_build == 1 assert fake_monitoring_metrics_gateway.docker_failed_build == 0 assert fake_monitoring_metrics_gateway.successful_build == 1 + assert fake_monitoring_metrics_gateway.build_time_seconds > 0 + if isinstance(request.model_endpoint_record.current_model_bundle.flavor, ArtifactLike): + if service == endpoint_builder_service_empty_docker_built: + assert sum(fake_monitoring_metrics_gateway.image_build_cache_hit.values()) > 0 + assert sum(fake_monitoring_metrics_gateway.image_build_cache_miss.values()) == 0 + else: + assert sum(fake_monitoring_metrics_gateway.image_build_cache_hit.values()) == 0 + assert sum(fake_monitoring_metrics_gateway.image_build_cache_miss.values()) > 0 @pytest.mark.asyncio @@ -199,7 +218,7 @@ async def test_build_endpoint_build_result_failed_yields_docker_build_failed_exc repo.add_model_endpoint_record(build_endpoint_request_sync_pytorch.model_endpoint_record) endpoint_builder_service_empty_docker_not_built.docker_repository.__setattr__( "build_image", - Mock(return_value=BuildImageResponse(status=False, logs="")), + Mock(return_value=BuildImageResponse(status=False, logs="", job_name="")), ) with pytest.raises(DockerBuildFailedException): await endpoint_builder_service_empty_docker_not_built.build_endpoint( diff --git a/server/tests/unit/infra/services/test_live_model_endpoint_service.py b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py similarity index 94% rename from server/tests/unit/infra/services/test_live_model_endpoint_service.py rename to model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py index 8b37b5a9..b67fc4cc 100644 --- a/server/tests/unit/infra/services/test_live_model_endpoint_service.py +++ b/model-engine/tests/unit/infra/services/test_live_model_endpoint_service.py @@ -2,21 +2,19 @@ from unittest.mock import AsyncMock import pytest -from llm_engine_server.core.domain_exceptions import ( - ObjectAlreadyExistsException, - ObjectNotFoundException, -) -from llm_engine_server.domain.entities import ( +from model_engine_server.domain.entities import ( ModelBundle, ModelEndpoint, ModelEndpointRecord, ModelEndpointStatus, ) -from llm_engine_server.domain.exceptions import ( +from model_engine_server.domain.exceptions import ( EndpointDeleteFailedException, ExistingEndpointOperationInProgressException, + ObjectAlreadyExistsException, + ObjectNotFoundException, ) -from llm_engine_server.infra.services import LiveModelEndpointService +from model_engine_server.infra.services import LiveModelEndpointService async def _create_model_endpoint_helper( @@ -54,6 +52,7 @@ async def _create_model_endpoint_helper( memory=infra_state.resource_state.memory, gpu_type=infra_state.resource_state.gpu_type, storage=infra_state.resource_state.storage, + nodes_per_worker=infra_state.resource_state.nodes_per_worker, optimize_costs=bool(infra_state.resource_state.optimize_costs), min_workers=infra_state.deployment_state.min_workers, max_workers=infra_state.deployment_state.max_workers, @@ -63,6 +62,7 @@ async def _create_model_endpoint_helper( results_s3_bucket=infra_state.results_s3_bucket, prewarm=prewarm, high_priority=high_priority, + billing_tags=infra_state.user_config_state.endpoint_config.billing_tags, owner=model_endpoint.record.owner, ) return model_endpoint_record @@ -112,6 +112,7 @@ async def test_create_get_model_endpoint_success( model_endpoint.record.created_at = model_endpoint_1.record.created_at model_endpoint.record.last_updated_at = model_endpoint_1.record.last_updated_at model_endpoint.record.id = model_endpoint_1.record.id + model_endpoint.infra_state.user_config_state.endpoint_config.billing_tags = model_endpoint_1.infra_state.user_config_state.endpoint_config.billing_tags # type: ignore # Use dict comparison because errors are more readable. assert model_endpoint.dict() == model_endpoint_1.dict() @@ -156,6 +157,7 @@ async def test_create_model_endpoint_raises_already_exists( memory=infra_state.resource_state.memory, gpu_type=infra_state.resource_state.gpu_type, storage=infra_state.resource_state.storage, + nodes_per_worker=infra_state.resource_state.nodes_per_worker, optimize_costs=bool(infra_state.resource_state.optimize_costs), min_workers=infra_state.deployment_state.min_workers, max_workers=infra_state.deployment_state.max_workers, @@ -238,6 +240,17 @@ async def test_create_update_model_endpoint_success( assert model_endpoint.infra_state.deployment_state.max_workers == update_kwargs["max_workers"] assert model_endpoint.infra_state.labels == update_kwargs["labels"] + # Now update min_worker only + update_kwargs: Any = dict( + min_workers=2, + ) + updated_model_endpoint_record = await fake_live_model_endpoint_service.update_model_endpoint( + model_endpoint_id=model_endpoint_record.id, **update_kwargs + ) + + # Make sure metadata is not updated + assert updated_model_endpoint_record.metadata == {"some_new_key": "some_new_values"} + @pytest.mark.skip(reason="Exception is temporarily disabled due to lock flakiness") @pytest.mark.asyncio diff --git a/server/tests/unit/infra/services/test_model_endpoint_cache_service.py b/model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py similarity index 92% rename from server/tests/unit/infra/services/test_model_endpoint_cache_service.py rename to model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py index 64f50eaa..fc3661b2 100644 --- a/server/tests/unit/infra/services/test_model_endpoint_cache_service.py +++ b/model-engine/tests/unit/infra/services/test_model_endpoint_cache_service.py @@ -1,5 +1,5 @@ import pytest -from llm_engine_server.infra.services.model_endpoint_cache_service import ( +from model_engine_server.infra.services.model_endpoint_cache_service import ( ModelEndpointCacheWriteService, ) @@ -18,9 +18,7 @@ async def test_model_endpoint_write_success( ) cache_write_service = ModelEndpointCacheWriteService( - fake_model_endpoint_cache_repository, - fake_resource_gateway, - fake_image_cache_service, + fake_model_endpoint_cache_repository, fake_resource_gateway, fake_image_cache_service ) await cache_write_service.execute(42) infra_state = await fake_model_endpoint_cache_repository.read_endpoint_info( diff --git a/requirements-dev.txt b/requirements-dev.txt index f293cb22..5e673c87 100644 --- a/requirements-dev.txt +++ b/requirements-dev.txt @@ -1,7 +1,10 @@ # Make sure to update .pre-commit-config.yaml to match versions! -black==22.12.0 -ruff==0.0.278 -isort==5.12.0 -mypy==1.3.0 -pip-tools==7.0.0 -poetry==1.5.1 +black[jupyter]==24.8.0 +datamodel-code-generator>=0.25.8 +ruff==0.6.8 +ipython==8.12.0 # 8.12.0 is the last version to support Python 3.8 +isort==5.13.2 +mypy==1.11.2 +pip-tools==7.4.1 +poetry==1.8.2 +pre-commit==3.8.0 diff --git a/requirements-docs.txt b/requirements-docs.txt index 51d81c23..01a02a52 100644 --- a/requirements-docs.txt +++ b/requirements-docs.txt @@ -5,8 +5,9 @@ mkdocs-material-extensions==1.1.1 mkdocs-render-swagger-plugin~=0.0.4 mkdocs-simple-hooks~=0.1.5 mkdocs-video~=1.5.0 -mkdocstrings[python]~=0.20.0 -pydantic~=1.10.0 +mkdocstrings[python]~=0.24.0 +pydantic==2.8.2 +griffe<1.0 neoteroi-mkdocs~=1.0.0 tabulate~=0.9.0 scale-llm-engine \ No newline at end of file diff --git a/scripts/generate-openai-types.sh b/scripts/generate-openai-types.sh new file mode 100755 index 00000000..e787cfe8 --- /dev/null +++ b/scripts/generate-openai-types.sh @@ -0,0 +1,53 @@ +#!/usr/bin/env bash + +SCRIPT_DIR=$( cd -- "$( dirname -- "${BASH_SOURCE[0]}" )" &> /dev/null && pwd ) +BASE_DIR=${SCRIPT_DIR}/.. + +DEST_DIR=${BASE_DIR}/model-engine/model_engine_server/common/types/gen +OPENAI_SPEC=${SCRIPT_DIR}/openai-spec.yaml + +# Generate OpenAPI types for server +datamodel-codegen \ + --input ${OPENAI_SPEC} \ + --input-file-type openapi \ + --output ${DEST_DIR}/openai.py \ + --output-model-type pydantic_v2.BaseModel \ + --enum-field-as-literal all \ + --field-constraints \ + --strict-nullable \ + --use-annotated + +# replace pydantic import w/ our custom module to replace the AnyUrl types +# Pydantic AnyUrl is super problematic for various reasons +sed -i 's/^from pydantic import /from model_engine_server.common.pydantic_types import /' ${DEST_DIR}/openai.py + + +CLIENT_DIR=${BASE_DIR}/clients/python/llmengine/data_types/gen + +# Generate OpenAPI types for client +# Client is using pydantic v1 +datamodel-codegen \ + --input ${OPENAI_SPEC} \ + --input-file-type openapi \ + --output ${CLIENT_DIR}/openai.py \ + --output-model-type pydantic.BaseModel \ + --enum-field-as-literal all \ + --field-constraints \ + --strict-nullable \ + --use-annotated + +# Ignore mypy for this file +# I tried updating mypy.ini to ignore this file, but it didn't work +sed -i '1s/^/# mypy: ignore-errors\n/' ${CLIENT_DIR}/openai.py + +# Add conditional import for pydantic v1 and v2 +# replace line starting with 'from pydantic ' with the following multiline python code +# import pydantic +# PYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.") +# +# if PYDANTIC_V2: +# from pydantic.v1 +# +# else: +# from pydantic +sed -i -E '/^from pydantic import /{s/^from pydantic import (.*)$/import pydantic\nPYDANTIC_V2 = hasattr(pydantic, "VERSION") and pydantic.VERSION.startswith("2.")\nif PYDANTIC_V2:\n from pydantic.v1 import \1 # noqa: F401\nelse:\n from pydantic import \1 # type: ignore # noqa: F401/}' ${CLIENT_DIR}/openai.py \ No newline at end of file diff --git a/scripts/openai-spec.yaml b/scripts/openai-spec.yaml new file mode 100644 index 00000000..01cbcde6 --- /dev/null +++ b/scripts/openai-spec.yaml @@ -0,0 +1,17072 @@ +# https://github.com/openai/openai-openapi/blob/423e672461b3d17f9829711e4a858e777252f077/openapi.yaml +openapi: 3.0.0 +info: + title: OpenAI API + description: The OpenAI REST API. Please see https://platform.openai.com/docs/api-reference for more details. + version: "2.3.0" + termsOfService: https://openai.com/policies/terms-of-use + contact: + name: OpenAI Support + url: https://help.openai.com/ + license: + name: MIT + url: https://github.com/openai/openai-openapi/blob/master/LICENSE +servers: + - url: https://api.openai.com/v1 +tags: + - name: Assistants + description: Build Assistants that can call models and use tools. + - name: Audio + description: Turn audio into text or text into audio. + - name: Chat + description: Given a list of messages comprising a conversation, the model will return a response. + - name: Completions + description: Given a prompt, the model will return one or more predicted completions, and can also return the probabilities of alternative tokens at each position. + - name: Embeddings + description: Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms. + - name: Fine-tuning + description: Manage fine-tuning jobs to tailor a model to your specific training data. + - name: Batch + description: Create large batches of API requests to run asynchronously. + - name: Files + description: Files are used to upload documents that can be used with features like Assistants and Fine-tuning. + - name: Uploads + description: Use Uploads to upload large files in multiple parts. + - name: Images + description: Given a prompt and/or an input image, the model will generate a new image. + - name: Models + description: List and describe the various models available in the API. + - name: Moderations + description: Given a input text, outputs if the model classifies it as potentially harmful. + - name: Audit Logs + description: List user actions and configuration changes within this organization. +paths: + # Note: When adding an endpoint, make sure you also add it in the `groups` section, in the end of this file, + # under the appropriate group + /chat/completions: + post: + operationId: createChatCompletion + tags: + - Chat + summary: Creates a model response for the given chat conversation. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateChatCompletionRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CreateChatCompletionResponse" + + x-oaiMeta: + name: Create chat completion + group: chat + returns: | + Returns a [chat completion](/docs/api-reference/chat/object) object, or a streamed sequence of [chat completion chunk](/docs/api-reference/chat/streaming) objects if the request is streamed. + path: create + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello!" + } + ] + }' + python: | + from openai import OpenAI + client = OpenAI() + + completion = client.chat.completions.create( + model="VAR_model_id", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ] + ) + + print(completion.choices[0].message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const completion = await openai.chat.completions.create({ + messages: [{ role: "system", content: "You are a helpful assistant." }], + model: "VAR_model_id", + }); + + console.log(completion.choices[0]); + } + + main(); + response: &chat_completion_example | + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "system_fingerprint": "fp_44709d6fcb", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "\n\nHello there, how may I assist you today?", + }, + "logprobs": null, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + - title: Image input + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "text", + "text": "What'\''s in this image?" + }, + { + "type": "image_url", + "image_url": { + "url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg" + } + } + ] + } + ], + "max_tokens": 300 + }' + python: | + from openai import OpenAI + + client = OpenAI() + + response = client.chat.completions.create( + model="gpt-4o", + messages=[ + { + "role": "user", + "content": [ + {"type": "text", "text": "What's in this image?"}, + { + "type": "image_url", + "image_url": "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + ], + } + ], + max_tokens=300, + ) + + print(response.choices[0]) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const response = await openai.chat.completions.create({ + model: "gpt-4o", + messages: [ + { + role: "user", + content: [ + { type: "text", text: "What's in this image?" }, + { + type: "image_url", + image_url: + "https://upload.wikimedia.org/wikipedia/commons/thumb/d/dd/Gfp-wisconsin-madison-the-nature-boardwalk.jpg/2560px-Gfp-wisconsin-madison-the-nature-boardwalk.jpg", + }, + ], + }, + ], + }); + console.log(response.choices[0]); + } + main(); + response: &chat_completion_image_example | + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-4o-mini", + "system_fingerprint": "fp_44709d6fcb", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "\n\nThis image shows a wooden boardwalk extending through a lush green marshland.", + }, + "logprobs": null, + "finish_reason": "stop" + }], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "messages": [ + { + "role": "system", + "content": "You are a helpful assistant." + }, + { + "role": "user", + "content": "Hello!" + } + ], + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + completion = client.chat.completions.create( + model="VAR_model_id", + messages=[ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ], + stream=True + ) + + for chunk in completion: + print(chunk.choices[0].delta) + + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const completion = await openai.chat.completions.create({ + model: "VAR_model_id", + messages: [ + {"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "Hello!"} + ], + stream: true, + }); + + for await (const chunk of completion) { + console.log(chunk.choices[0].delta.content); + } + } + + main(); + response: &chat_completion_chunk_example | + {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]} + + {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]} + + .... + + {"id":"chatcmpl-123","object":"chat.completion.chunk","created":1694268190,"model":"gpt-4o-mini", "system_fingerprint": "fp_44709d6fcb", "choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"stop"}]} + - title: Functions + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "gpt-4o", + "messages": [ + { + "role": "user", + "content": "What'\''s the weather like in Boston today?" + } + ], + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "tool_choice": "auto" + }' + python: | + from openai import OpenAI + client = OpenAI() + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ] + messages = [{"role": "user", "content": "What's the weather like in Boston today?"}] + completion = client.chat.completions.create( + model="VAR_model_id", + messages=messages, + tools=tools, + tool_choice="auto" + ) + + print(completion) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const messages = [{"role": "user", "content": "What's the weather like in Boston today?"}]; + const tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ]; + + const response = await openai.chat.completions.create({ + model: "gpt-4o", + messages: messages, + tools: tools, + tool_choice: "auto", + }); + + console.log(response); + } + + main(); + response: &chat_completion_function_example | + { + "id": "chatcmpl-abc123", + "object": "chat.completion", + "created": 1699896916, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": null, + "tool_calls": [ + { + "id": "call_abc123", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\n\"location\": \"Boston, MA\"\n}" + } + } + ] + }, + "logprobs": null, + "finish_reason": "tool_calls" + } + ], + "usage": { + "prompt_tokens": 82, + "completion_tokens": 17, + "total_tokens": 99 + } + } + - title: Logprobs + request: + curl: | + curl https://api.openai.com/v1/chat/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "messages": [ + { + "role": "user", + "content": "Hello!" + } + ], + "logprobs": true, + "top_logprobs": 2 + }' + python: | + from openai import OpenAI + client = OpenAI() + + completion = client.chat.completions.create( + model="VAR_model_id", + messages=[ + {"role": "user", "content": "Hello!"} + ], + logprobs=True, + top_logprobs=2 + ) + + print(completion.choices[0].message) + print(completion.choices[0].logprobs) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const completion = await openai.chat.completions.create({ + messages: [{ role: "user", content: "Hello!" }], + model: "VAR_model_id", + logprobs: true, + top_logprobs: 2, + }); + + console.log(completion.choices[0]); + } + + main(); + response: | + { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1702685778, + "model": "gpt-4o-mini", + "choices": [ + { + "index": 0, + "message": { + "role": "assistant", + "content": "Hello! How can I assist you today?" + }, + "logprobs": { + "content": [ + { + "token": "Hello", + "logprob": -0.31725305, + "bytes": [72, 101, 108, 108, 111], + "top_logprobs": [ + { + "token": "Hello", + "logprob": -0.31725305, + "bytes": [72, 101, 108, 108, 111] + }, + { + "token": "Hi", + "logprob": -1.3190403, + "bytes": [72, 105] + } + ] + }, + { + "token": "!", + "logprob": -0.02380986, + "bytes": [ + 33 + ], + "top_logprobs": [ + { + "token": "!", + "logprob": -0.02380986, + "bytes": [33] + }, + { + "token": " there", + "logprob": -3.787621, + "bytes": [32, 116, 104, 101, 114, 101] + } + ] + }, + { + "token": " How", + "logprob": -0.000054669687, + "bytes": [32, 72, 111, 119], + "top_logprobs": [ + { + "token": " How", + "logprob": -0.000054669687, + "bytes": [32, 72, 111, 119] + }, + { + "token": "<|end|>", + "logprob": -10.953937, + "bytes": null + } + ] + }, + { + "token": " can", + "logprob": -0.015801601, + "bytes": [32, 99, 97, 110], + "top_logprobs": [ + { + "token": " can", + "logprob": -0.015801601, + "bytes": [32, 99, 97, 110] + }, + { + "token": " may", + "logprob": -4.161023, + "bytes": [32, 109, 97, 121] + } + ] + }, + { + "token": " I", + "logprob": -3.7697225e-6, + "bytes": [ + 32, + 73 + ], + "top_logprobs": [ + { + "token": " I", + "logprob": -3.7697225e-6, + "bytes": [32, 73] + }, + { + "token": " assist", + "logprob": -13.596657, + "bytes": [32, 97, 115, 115, 105, 115, 116] + } + ] + }, + { + "token": " assist", + "logprob": -0.04571125, + "bytes": [32, 97, 115, 115, 105, 115, 116], + "top_logprobs": [ + { + "token": " assist", + "logprob": -0.04571125, + "bytes": [32, 97, 115, 115, 105, 115, 116] + }, + { + "token": " help", + "logprob": -3.1089056, + "bytes": [32, 104, 101, 108, 112] + } + ] + }, + { + "token": " you", + "logprob": -5.4385737e-6, + "bytes": [32, 121, 111, 117], + "top_logprobs": [ + { + "token": " you", + "logprob": -5.4385737e-6, + "bytes": [32, 121, 111, 117] + }, + { + "token": " today", + "logprob": -12.807695, + "bytes": [32, 116, 111, 100, 97, 121] + } + ] + }, + { + "token": " today", + "logprob": -0.0040071653, + "bytes": [32, 116, 111, 100, 97, 121], + "top_logprobs": [ + { + "token": " today", + "logprob": -0.0040071653, + "bytes": [32, 116, 111, 100, 97, 121] + }, + { + "token": "?", + "logprob": -5.5247097, + "bytes": [63] + } + ] + }, + { + "token": "?", + "logprob": -0.0008108172, + "bytes": [63], + "top_logprobs": [ + { + "token": "?", + "logprob": -0.0008108172, + "bytes": [63] + }, + { + "token": "?\n", + "logprob": -7.184561, + "bytes": [63, 10] + } + ] + } + ] + }, + "finish_reason": "stop" + } + ], + "usage": { + "prompt_tokens": 9, + "completion_tokens": 9, + "total_tokens": 18 + }, + "system_fingerprint": null + } + + /completions: + post: + operationId: createCompletion + tags: + - Completions + summary: Creates a completion for the provided prompt and parameters. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateCompletionRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CreateCompletionResponse" + x-oaiMeta: + name: Create completion + group: completions + returns: | + Returns a [completion](/docs/api-reference/completions/object) object, or a sequence of completion objects if the request is streamed. + legacy: true + examples: + - title: No streaming + request: + curl: | + curl https://api.openai.com/v1/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "prompt": "Say this is a test", + "max_tokens": 7, + "temperature": 0 + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.completions.create( + model="VAR_model_id", + prompt="Say this is a test", + max_tokens=7, + temperature=0 + ) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const completion = await openai.completions.create({ + model: "VAR_model_id", + prompt: "Say this is a test.", + max_tokens: 7, + temperature: 0, + }); + + console.log(completion); + } + main(); + response: | + { + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "VAR_model_id", + "system_fingerprint": "fp_44709d6fcb", + "choices": [ + { + "text": "\n\nThis is indeed a test", + "index": 0, + "logprobs": null, + "finish_reason": "length" + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12 + } + } + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/completions \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "VAR_model_id", + "prompt": "Say this is a test", + "max_tokens": 7, + "temperature": 0, + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + for chunk in client.completions.create( + model="VAR_model_id", + prompt="Say this is a test", + max_tokens=7, + temperature=0, + stream=True + ): + print(chunk.choices[0].text) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const stream = await openai.completions.create({ + model: "VAR_model_id", + prompt: "Say this is a test.", + stream: true, + }); + + for await (const chunk of stream) { + console.log(chunk.choices[0].text) + } + } + main(); + response: | + { + "id": "cmpl-7iA7iJjj8V2zOkCGvWF2hAkDWBQZe", + "object": "text_completion", + "created": 1690759702, + "choices": [ + { + "text": "This", + "index": 0, + "logprobs": null, + "finish_reason": null + } + ], + "model": "gpt-3.5-turbo-instruct" + "system_fingerprint": "fp_44709d6fcb", + } + + /images/generations: + post: + operationId: createImage + tags: + - Images + summary: Creates an image given a prompt. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateImageRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ImagesResponse" + x-oaiMeta: + name: Create image + group: images + returns: Returns a list of [image](/docs/api-reference/images/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/images/generations \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "model": "dall-e-3", + "prompt": "A cute baby sea otter", + "n": 1, + "size": "1024x1024" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.images.generate( + model="dall-e-3", + prompt="A cute baby sea otter", + n=1, + size="1024x1024" + ) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const image = await openai.images.generate({ model: "dall-e-3", prompt: "A cute baby sea otter" }); + + console.log(image.data); + } + main(); + response: | + { + "created": 1589478378, + "data": [ + { + "url": "https://..." + }, + { + "url": "https://..." + } + ] + } + /images/edits: + post: + operationId: createImageEdit + tags: + - Images + summary: Creates an edited or extended image given an original image and a prompt. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateImageEditRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ImagesResponse" + x-oaiMeta: + name: Create image edit + group: images + returns: Returns a list of [image](/docs/api-reference/images/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/images/edits \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -F image="@otter.png" \ + -F mask="@mask.png" \ + -F prompt="A cute baby sea otter wearing a beret" \ + -F n=2 \ + -F size="1024x1024" + python: | + from openai import OpenAI + client = OpenAI() + + client.images.edit( + image=open("otter.png", "rb"), + mask=open("mask.png", "rb"), + prompt="A cute baby sea otter wearing a beret", + n=2, + size="1024x1024" + ) + node.js: |- + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const image = await openai.images.edit({ + image: fs.createReadStream("otter.png"), + mask: fs.createReadStream("mask.png"), + prompt: "A cute baby sea otter wearing a beret", + }); + + console.log(image.data); + } + main(); + response: | + { + "created": 1589478378, + "data": [ + { + "url": "https://..." + }, + { + "url": "https://..." + } + ] + } + /images/variations: + post: + operationId: createImageVariation + tags: + - Images + summary: Creates a variation of a given image. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateImageVariationRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ImagesResponse" + x-oaiMeta: + name: Create image variation + group: images + returns: Returns a list of [image](/docs/api-reference/images/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/images/variations \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -F image="@otter.png" \ + -F n=2 \ + -F size="1024x1024" + python: | + from openai import OpenAI + client = OpenAI() + + response = client.images.create_variation( + image=open("image_edit_original.png", "rb"), + n=2, + size="1024x1024" + ) + node.js: |- + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const image = await openai.images.createVariation({ + image: fs.createReadStream("otter.png"), + }); + + console.log(image.data); + } + main(); + response: | + { + "created": 1589478378, + "data": [ + { + "url": "https://..." + }, + { + "url": "https://..." + } + ] + } + + /embeddings: + post: + operationId: createEmbedding + tags: + - Embeddings + summary: Creates an embedding vector representing the input text. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateEmbeddingRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CreateEmbeddingResponse" + x-oaiMeta: + name: Create embeddings + group: embeddings + returns: A list of [embedding](/docs/api-reference/embeddings/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/embeddings \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "input": "The food was delicious and the waiter...", + "model": "text-embedding-ada-002", + "encoding_format": "float" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.embeddings.create( + model="text-embedding-ada-002", + input="The food was delicious and the waiter...", + encoding_format="float" + ) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const embedding = await openai.embeddings.create({ + model: "text-embedding-ada-002", + input: "The quick brown fox jumped over the lazy dog", + encoding_format: "float", + }); + + console.log(embedding); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "object": "embedding", + "embedding": [ + 0.0023064255, + -0.009327292, + .... (1536 floats total for ada-002) + -0.0028842222, + ], + "index": 0 + } + ], + "model": "text-embedding-ada-002", + "usage": { + "prompt_tokens": 8, + "total_tokens": 8 + } + } + + /audio/speech: + post: + operationId: createSpeech + tags: + - Audio + summary: Generates audio from the input text. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateSpeechRequest" + responses: + "200": + description: OK + headers: + Transfer-Encoding: + schema: + type: string + description: chunked + content: + application/octet-stream: + schema: + type: string + format: binary + x-oaiMeta: + name: Create speech + group: audio + returns: The audio file content. + examples: + request: + curl: | + curl https://api.openai.com/v1/audio/speech \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "model": "tts-1", + "input": "The quick brown fox jumped over the lazy dog.", + "voice": "alloy" + }' \ + --output speech.mp3 + python: | + from pathlib import Path + import openai + + speech_file_path = Path(__file__).parent / "speech.mp3" + response = openai.audio.speech.create( + model="tts-1", + voice="alloy", + input="The quick brown fox jumped over the lazy dog." + ) + response.stream_to_file(speech_file_path) + node: | + import fs from "fs"; + import path from "path"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + const speechFile = path.resolve("./speech.mp3"); + + async function main() { + const mp3 = await openai.audio.speech.create({ + model: "tts-1", + voice: "alloy", + input: "Today is a wonderful day to build something people love!", + }); + console.log(speechFile); + const buffer = Buffer.from(await mp3.arrayBuffer()); + await fs.promises.writeFile(speechFile, buffer); + } + main(); + /audio/transcriptions: + post: + operationId: createTranscription + tags: + - Audio + summary: Transcribes audio into the input language. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateTranscriptionRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + oneOf: + - $ref: "#/components/schemas/CreateTranscriptionResponseJson" + - $ref: "#/components/schemas/CreateTranscriptionResponseVerboseJson" + x-oaiMeta: + name: Create transcription + group: audio + returns: The [transcription object](/docs/api-reference/audio/json-object) or a [verbose transcription object](/docs/api-reference/audio/verbose-json-object). + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/audio/transcriptions \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/audio.mp3" \ + -F model="whisper-1" + python: | + from openai import OpenAI + client = OpenAI() + + audio_file = open("speech.mp3", "rb") + transcript = client.audio.transcriptions.create( + model="whisper-1", + file=audio_file + ) + node: | + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const transcription = await openai.audio.transcriptions.create({ + file: fs.createReadStream("audio.mp3"), + model: "whisper-1", + }); + + console.log(transcription.text); + } + main(); + response: &basic_transcription_response_example | + { + "text": "Imagine the wildest idea that you've ever had, and you're curious about how it might scale to something that's a 100, a 1,000 times bigger. This is a place where you can get to do that." + } + - title: Word timestamps + request: + curl: | + curl https://api.openai.com/v1/audio/transcriptions \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/audio.mp3" \ + -F "timestamp_granularities[]=word" \ + -F model="whisper-1" \ + -F response_format="verbose_json" + python: | + from openai import OpenAI + client = OpenAI() + + audio_file = open("speech.mp3", "rb") + transcript = client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="verbose_json", + timestamp_granularities=["word"] + ) + + print(transcript.words) + node: | + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const transcription = await openai.audio.transcriptions.create({ + file: fs.createReadStream("audio.mp3"), + model: "whisper-1", + response_format: "verbose_json", + timestamp_granularities: ["word"] + }); + + console.log(transcription.text); + } + main(); + response: | + { + "task": "transcribe", + "language": "english", + "duration": 8.470000267028809, + "text": "The beach was a popular spot on a hot summer day. People were swimming in the ocean, building sandcastles, and playing beach volleyball.", + "words": [ + { + "word": "The", + "start": 0.0, + "end": 0.23999999463558197 + }, + ... + { + "word": "volleyball", + "start": 7.400000095367432, + "end": 7.900000095367432 + } + ] + } + - title: Segment timestamps + request: + curl: | + curl https://api.openai.com/v1/audio/transcriptions \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/audio.mp3" \ + -F "timestamp_granularities[]=segment" \ + -F model="whisper-1" \ + -F response_format="verbose_json" + python: | + from openai import OpenAI + client = OpenAI() + + audio_file = open("speech.mp3", "rb") + transcript = client.audio.transcriptions.create( + file=audio_file, + model="whisper-1", + response_format="verbose_json", + timestamp_granularities=["segment"] + ) + + print(transcript.words) + node: | + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const transcription = await openai.audio.transcriptions.create({ + file: fs.createReadStream("audio.mp3"), + model: "whisper-1", + response_format: "verbose_json", + timestamp_granularities: ["segment"] + }); + + console.log(transcription.text); + } + main(); + response: &verbose_transcription_response_example | + { + "task": "transcribe", + "language": "english", + "duration": 8.470000267028809, + "text": "The beach was a popular spot on a hot summer day. People were swimming in the ocean, building sandcastles, and playing beach volleyball.", + "segments": [ + { + "id": 0, + "seek": 0, + "start": 0.0, + "end": 3.319999933242798, + "text": " The beach was a popular spot on a hot summer day.", + "tokens": [ + 50364, 440, 7534, 390, 257, 3743, 4008, 322, 257, 2368, 4266, 786, 13, 50530 + ], + "temperature": 0.0, + "avg_logprob": -0.2860786020755768, + "compression_ratio": 1.2363636493682861, + "no_speech_prob": 0.00985979475080967 + }, + ... + ] + } + /audio/translations: + post: + operationId: createTranslation + tags: + - Audio + summary: Translates audio into English. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateTranslationRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + oneOf: + - $ref: "#/components/schemas/CreateTranslationResponseJson" + - $ref: "#/components/schemas/CreateTranslationResponseVerboseJson" + x-oaiMeta: + name: Create translation + group: audio + returns: The translated text. + examples: + request: + curl: | + curl https://api.openai.com/v1/audio/translations \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: multipart/form-data" \ + -F file="@/path/to/file/german.m4a" \ + -F model="whisper-1" + python: | + from openai import OpenAI + client = OpenAI() + + audio_file = open("speech.mp3", "rb") + transcript = client.audio.translations.create( + model="whisper-1", + file=audio_file + ) + node: | + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const translation = await openai.audio.translations.create({ + file: fs.createReadStream("speech.mp3"), + model: "whisper-1", + }); + + console.log(translation.text); + } + main(); + response: | + { + "text": "Hello, my name is Wolfgang and I come from Germany. Where are you heading today?" + } + + /files: + get: + operationId: listFiles + tags: + - Files + summary: Returns a list of files that belong to the user's organization. + parameters: + - in: query + name: purpose + required: false + schema: + type: string + description: Only return files with the given purpose. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListFilesResponse" + x-oaiMeta: + name: List files + group: files + returns: A list of [File](/docs/api-reference/files/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.files.list() + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.files.list(); + + for await (const file of list) { + console.log(file); + } + } + + main(); + response: | + { + "data": [ + { + "id": "file-abc123", + "object": "file", + "bytes": 175, + "created_at": 1613677385, + "filename": "salesOverview.pdf", + "purpose": "assistants", + }, + { + "id": "file-abc123", + "object": "file", + "bytes": 140, + "created_at": 1613779121, + "filename": "puppy.jsonl", + "purpose": "fine-tune", + } + ], + "object": "list" + } + post: + operationId: createFile + tags: + - Files + summary: | + Upload a file that can be used across various endpoints. Individual files can be up to 512 MB, and the size of all files uploaded by one organization can be up to 100 GB. + + The Assistants API supports files up to 2 million tokens and of specific file types. See the [Assistants Tools guide](/docs/assistants/tools) for details. + + The Fine-tuning API only supports `.jsonl` files. The input also has certain required formats for fine-tuning [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) models. + + The Batch API only supports `.jsonl` files up to 100 MB in size. The input also has a specific required [format](/docs/api-reference/batch/request-input). + + Please [contact us](https://help.openai.com/) if you need to increase these storage limits. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/CreateFileRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/OpenAIFile" + x-oaiMeta: + name: Upload file + group: files + returns: The uploaded [File](/docs/api-reference/files/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -F purpose="fine-tune" \ + -F file="@mydata.jsonl" + python: | + from openai import OpenAI + client = OpenAI() + + client.files.create( + file=open("mydata.jsonl", "rb"), + purpose="fine-tune" + ) + node.js: |- + import fs from "fs"; + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const file = await openai.files.create({ + file: fs.createReadStream("mydata.jsonl"), + purpose: "fine-tune", + }); + + console.log(file); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "file", + "bytes": 120000, + "created_at": 1677610602, + "filename": "mydata.jsonl", + "purpose": "fine-tune", + } + /files/{file_id}: + delete: + operationId: deleteFile + tags: + - Files + summary: Delete a file. + parameters: + - in: path + name: file_id + required: true + schema: + type: string + description: The ID of the file to use for this request. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteFileResponse" + x-oaiMeta: + name: Delete file + group: files + returns: Deletion status. + examples: + request: + curl: | + curl https://api.openai.com/v1/files/file-abc123 \ + -X DELETE \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.files.delete("file-abc123") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const file = await openai.files.del("file-abc123"); + + console.log(file); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "file", + "deleted": true + } + get: + operationId: retrieveFile + tags: + - Files + summary: Returns information about a specific file. + parameters: + - in: path + name: file_id + required: true + schema: + type: string + description: The ID of the file to use for this request. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/OpenAIFile" + x-oaiMeta: + name: Retrieve file + group: files + returns: The [File](/docs/api-reference/files/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/files/file-abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.files.retrieve("file-abc123") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const file = await openai.files.retrieve("file-abc123"); + + console.log(file); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "file", + "bytes": 120000, + "created_at": 1677610602, + "filename": "mydata.jsonl", + "purpose": "fine-tune", + } + /files/{file_id}/content: + get: + operationId: downloadFile + tags: + - Files + summary: Returns the contents of the specified file. + parameters: + - in: path + name: file_id + required: true + schema: + type: string + description: The ID of the file to use for this request. + responses: + "200": + description: OK + content: + application/json: + schema: + type: string + x-oaiMeta: + name: Retrieve file content + group: files + returns: The file content. + examples: + request: + curl: | + curl https://api.openai.com/v1/files/file-abc123/content \ + -H "Authorization: Bearer $OPENAI_API_KEY" > file.jsonl + python: | + from openai import OpenAI + client = OpenAI() + + content = client.files.content("file-abc123") + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const file = await openai.files.content("file-abc123"); + + console.log(file); + } + + main(); + /uploads: + post: + operationId: createUpload + tags: + - Uploads + summary: | + Creates an intermediate [Upload](/docs/api-reference/uploads/object) object that you can add [Parts](/docs/api-reference/uploads/part-object) to. Currently, an Upload can accept at most 8 GB in total and expires after an hour after you create it. + + Once you complete the Upload, we will create a [File](/docs/api-reference/files/object) object that contains all the parts you uploaded. This File is usable in the rest of our platform as a regular File object. + + For certain `purpose`s, the correct `mime_type` must be specified. Please refer to documentation for the supported MIME types for your use case: + - [Assistants](/docs/assistants/tools/file-search/supported-files) + + For guidance on the proper filename extensions for each purpose, please follow the documentation on [creating a File](/docs/api-reference/files/create). + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateUploadRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Upload" + x-oaiMeta: + name: Create upload + group: uploads + returns: The [Upload](/docs/api-reference/uploads/object) object with status `pending`. + examples: + request: + curl: | + curl https://api.openai.com/v1/uploads \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "purpose": "fine-tune", + "filename": "training_examples.jsonl", + "bytes": 2147483648, + "mime_type": "text/jsonl" + }' + response: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "pending", + "expires_at": 1719127296 + } + + /uploads/{upload_id}/parts: + post: + operationId: addUploadPart + tags: + - Uploads + summary: | + Adds a [Part](/docs/api-reference/uploads/part-object) to an [Upload](/docs/api-reference/uploads/object) object. A Part represents a chunk of bytes from the file you are trying to upload. + + Each Part can be at most 64 MB, and you can add Parts until you hit the Upload maximum of 8 GB. + + It is possible to add multiple Parts in parallel. You can decide the intended order of the Parts when you [complete the Upload](/docs/api-reference/uploads/complete). + parameters: + - in: path + name: upload_id + required: true + schema: + type: string + example: upload_abc123 + description: | + The ID of the Upload. + requestBody: + required: true + content: + multipart/form-data: + schema: + $ref: "#/components/schemas/AddUploadPartRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/UploadPart" + x-oaiMeta: + name: Add upload part + group: uploads + returns: The upload [Part](/docs/api-reference/uploads/part-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/uploads/upload_abc123/parts + -F data="aHR0cHM6Ly9hcGkub3BlbmFpLmNvbS92MS91cGxvYWRz..." + response: | + { + "id": "part_def456", + "object": "upload.part", + "created_at": 1719185911, + "upload_id": "upload_abc123" + } + + /uploads/{upload_id}/complete: + post: + operationId: completeUpload + tags: + - Uploads + summary: | + Completes the [Upload](/docs/api-reference/uploads/object). + + Within the returned Upload object, there is a nested [File](/docs/api-reference/files/object) object that is ready to use in the rest of the platform. + + You can specify the order of the Parts by passing in an ordered list of the Part IDs. + + The number of bytes uploaded upon completion must match the number of bytes initially specified when creating the Upload object. No Parts may be added after an Upload is completed. + parameters: + - in: path + name: upload_id + required: true + schema: + type: string + example: upload_abc123 + description: | + The ID of the Upload. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CompleteUploadRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Upload" + x-oaiMeta: + name: Complete upload + group: uploads + returns: The [Upload](/docs/api-reference/uploads/object) object with status `completed` with an additional `file` property containing the created usable File object. + examples: + request: + curl: | + curl https://api.openai.com/v1/uploads/upload_abc123/complete + -d '{ + "part_ids": ["part_def456", "part_ghi789"] + }' + response: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "completed", + "expires_at": 1719127296, + "file": { + "id": "file-xyz321", + "object": "file", + "bytes": 2147483648, + "created_at": 1719186911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + } + } + + /uploads/{upload_id}/cancel: + post: + operationId: cancelUpload + tags: + - Uploads + summary: | + Cancels the Upload. No Parts may be added after an Upload is cancelled. + parameters: + - in: path + name: upload_id + required: true + schema: + type: string + example: upload_abc123 + description: | + The ID of the Upload. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Upload" + x-oaiMeta: + name: Cancel upload + group: uploads + returns: The [Upload](/docs/api-reference/uploads/object) object with status `cancelled`. + examples: + request: + curl: | + curl https://api.openai.com/v1/uploads/upload_abc123/cancel + response: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "cancelled", + "expires_at": 1719127296 + } + + /fine_tuning/jobs: + post: + operationId: createFineTuningJob + tags: + - Fine-tuning + summary: | + Creates a fine-tuning job which begins the process of creating a new model from a given dataset. + + Response includes details of the enqueued job including job status and the name of the fine-tuned models once complete. + + [Learn more about fine-tuning](/docs/guides/fine-tuning) + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateFineTuningJobRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/FineTuningJob" + x-oaiMeta: + name: Create fine-tuning job + group: fine-tuning + returns: A [fine-tuning.job](/docs/api-reference/fine-tuning/object) object. + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "training_file": "file-BK7bzQj3FfZFXr7DbL6xJwfo", + "model": "gpt-4o-mini" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.create( + training_file="file-abc123", + model="gpt-4o-mini" + ) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.create({ + training_file: "file-abc123" + }); + + console.log(fineTune); + } + + main(); + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "status": "queued", + "validation_file": null, + "training_file": "file-abc123", + } + - title: Epochs + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "training_file": "file-abc123", + "model": "gpt-4o-mini", + "hyperparameters": { + "n_epochs": 2 + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.create( + training_file="file-abc123", + model="gpt-4o-mini", + hyperparameters={ + "n_epochs":2 + } + ) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.create({ + training_file: "file-abc123", + model: "gpt-4o-mini", + hyperparameters: { n_epochs: 2 } + }); + + console.log(fineTune); + } + + main(); + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "status": "queued", + "validation_file": null, + "training_file": "file-abc123", + "hyperparameters": {"n_epochs": 2}, + } + - title: Validation file + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "training_file": "file-abc123", + "validation_file": "file-abc123", + "model": "gpt-4o-mini" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.create( + training_file="file-abc123", + validation_file="file-def456", + model="gpt-4o-mini" + ) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.create({ + training_file: "file-abc123", + validation_file: "file-abc123" + }); + + console.log(fineTune); + } + + main(); + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "status": "queued", + "validation_file": "file-abc123", + "training_file": "file-abc123", + } + - title: W&B Integration + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "training_file": "file-abc123", + "validation_file": "file-abc123", + "model": "gpt-4o-mini", + "integrations": [ + { + "type": "wandb", + "wandb": { + "project": "my-wandb-project", + "name": "ft-run-display-name" + "tags": [ + "first-experiment", "v2" + ] + } + } + ] + }' + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "status": "queued", + "validation_file": "file-abc123", + "training_file": "file-abc123", + "integrations": [ + { + "type": "wandb", + "wandb": { + "project": "my-wandb-project", + "entity": None, + "run_id": "ftjob-abc123" + } + } + ] + } + get: + operationId: listPaginatedFineTuningJobs + tags: + - Fine-tuning + summary: | + List your organization's fine-tuning jobs + parameters: + - name: after + in: query + description: Identifier for the last job from the previous pagination request. + required: false + schema: + type: string + - name: limit + in: query + description: Number of fine-tuning jobs to retrieve. + required: false + schema: + type: integer + default: 20 + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListPaginatedFineTuningJobsResponse" + x-oaiMeta: + name: List fine-tuning jobs + group: fine-tuning + returns: A list of paginated [fine-tuning job](/docs/api-reference/fine-tuning/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs?limit=2 \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.list() + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.fineTuning.jobs.list(); + + for await (const fineTune of list) { + console.log(fineTune); + } + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "object": "fine_tuning.job.event", + "id": "ft-event-TjX0lMfOniCZX64t9PUQT5hn", + "created_at": 1689813489, + "level": "warn", + "message": "Fine tuning process stopping due to job cancellation", + "data": null, + "type": "message" + }, + { ... }, + { ... } + ], "has_more": true + } + /fine_tuning/jobs/{fine_tuning_job_id}: + get: + operationId: retrieveFineTuningJob + tags: + - Fine-tuning + summary: | + Get info about a fine-tuning job. + + [Learn more about fine-tuning](/docs/guides/fine-tuning) + parameters: + - in: path + name: fine_tuning_job_id + required: true + schema: + type: string + example: ft-AF1WoRqd3aJAHsqc9NY7iL8F + description: | + The ID of the fine-tuning job. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/FineTuningJob" + x-oaiMeta: + name: Retrieve fine-tuning job + group: fine-tuning + returns: The [fine-tuning](/docs/api-reference/fine-tuning/object) object with the given ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs/ft-AF1WoRqd3aJAHsqc9NY7iL8F \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.retrieve("ftjob-abc123") + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.retrieve("ftjob-abc123"); + + console.log(fineTune); + } + + main(); + response: &fine_tuning_example | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "davinci-002", + "created_at": 1692661014, + "finished_at": 1692661190, + "fine_tuned_model": "ft:davinci-002:my-org:custom_suffix:7q8mpxmy", + "organization_id": "org-123", + "result_files": [ + "file-abc123" + ], + "status": "succeeded", + "validation_file": null, + "training_file": "file-abc123", + "hyperparameters": { + "n_epochs": 4, + "batch_size": 1, + "learning_rate_multiplier": 1.0 + }, + "trained_tokens": 5768, + "integrations": [], + "seed": 0, + "estimated_finish": 0 + } + /fine_tuning/jobs/{fine_tuning_job_id}/events: + get: + operationId: listFineTuningEvents + tags: + - Fine-tuning + summary: | + Get status updates for a fine-tuning job. + parameters: + - in: path + name: fine_tuning_job_id + required: true + schema: + type: string + example: ft-AF1WoRqd3aJAHsqc9NY7iL8F + description: | + The ID of the fine-tuning job to get events for. + - name: after + in: query + description: Identifier for the last event from the previous pagination request. + required: false + schema: + type: string + - name: limit + in: query + description: Number of events to retrieve. + required: false + schema: + type: integer + default: 20 + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListFineTuningJobEventsResponse" + x-oaiMeta: + name: List fine-tuning events + group: fine-tuning + returns: A list of fine-tuning event objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs/ftjob-abc123/events \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.list_events( + fine_tuning_job_id="ftjob-abc123", + limit=2 + ) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.fineTuning.list_events(id="ftjob-abc123", limit=2); + + for await (const fineTune of list) { + console.log(fineTune); + } + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "object": "fine_tuning.job.event", + "id": "ft-event-ddTJfwuMVpfLXseO0Am0Gqjm", + "created_at": 1721764800, + "level": "info", + "message": "Fine tuning job successfully completed", + "data": null, + "type": "message" + }, + { + "object": "fine_tuning.job.event", + "id": "ft-event-tyiGuB72evQncpH87xe505Sv", + "created_at": 1721764800, + "level": "info", + "message": "New fine-tuned model created: ft:gpt-4o-mini:openai::7p4lURel", + "data": null, + "type": "message" + } + ], + "has_more": true + } + /fine_tuning/jobs/{fine_tuning_job_id}/cancel: + post: + operationId: cancelFineTuningJob + tags: + - Fine-tuning + summary: | + Immediately cancel a fine-tune job. + parameters: + - in: path + name: fine_tuning_job_id + required: true + schema: + type: string + example: ft-AF1WoRqd3aJAHsqc9NY7iL8F + description: | + The ID of the fine-tuning job to cancel. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/FineTuningJob" + x-oaiMeta: + name: Cancel fine-tuning + group: fine-tuning + returns: The cancelled [fine-tuning](/docs/api-reference/fine-tuning/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/fine_tuning/jobs/ftjob-abc123/cancel \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.fine_tuning.jobs.cancel("ftjob-abc123") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const fineTune = await openai.fineTuning.jobs.cancel("ftjob-abc123"); + + console.log(fineTune); + } + main(); + response: | + { + "object": "fine_tuning.job", + "id": "ftjob-abc123", + "model": "gpt-4o-mini-2024-07-18", + "created_at": 1721764800, + "fine_tuned_model": null, + "organization_id": "org-123", + "result_files": [], + "hyperparameters": { + "n_epochs": "auto" + }, + "status": "cancelled", + "validation_file": "file-abc123", + "training_file": "file-abc123" + } + /fine_tuning/jobs/{fine_tuning_job_id}/checkpoints: + get: + operationId: listFineTuningJobCheckpoints + tags: + - Fine-tuning + summary: | + List checkpoints for a fine-tuning job. + parameters: + - in: path + name: fine_tuning_job_id + required: true + schema: + type: string + example: ft-AF1WoRqd3aJAHsqc9NY7iL8F + description: | + The ID of the fine-tuning job to get checkpoints for. + - name: after + in: query + description: Identifier for the last checkpoint ID from the previous pagination request. + required: false + schema: + type: string + - name: limit + in: query + description: Number of checkpoints to retrieve. + required: false + schema: + type: integer + default: 10 + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListFineTuningJobCheckpointsResponse" + x-oaiMeta: + name: List fine-tuning checkpoints + group: fine-tuning + returns: A list of fine-tuning [checkpoint objects](/docs/api-reference/fine-tuning/checkpoint-object) for a fine-tuning job. + examples: + request: + curl: | + curl https://api.openai.com/v1/fine_tuning/jobs/ftjob-abc123/checkpoints \ + -H "Authorization: Bearer $OPENAI_API_KEY" + response: | + { + "object": "list" + "data": [ + { + "object": "fine_tuning.job.checkpoint", + "id": "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB", + "created_at": 1721764867, + "fine_tuned_model_checkpoint": "ft:gpt-4o-mini-2024-07-18:my-org:custom-suffix:96olL566:ckpt-step-2000", + "metrics": { + "full_valid_loss": 0.134, + "full_valid_mean_token_accuracy": 0.874 + }, + "fine_tuning_job_id": "ftjob-abc123", + "step_number": 2000, + }, + { + "object": "fine_tuning.job.checkpoint", + "id": "ftckpt_enQCFmOTGj3syEpYVhBRLTSy", + "created_at": 1721764800, + "fine_tuned_model_checkpoint": "ft:gpt-4o-mini-2024-07-18:my-org:custom-suffix:7q8mpxmy:ckpt-step-1000", + "metrics": { + "full_valid_loss": 0.167, + "full_valid_mean_token_accuracy": 0.781 + }, + "fine_tuning_job_id": "ftjob-abc123", + "step_number": 1000, + }, + ], + "first_id": "ftckpt_zc4Q7MP6XxulcVzj4MZdwsAB", + "last_id": "ftckpt_enQCFmOTGj3syEpYVhBRLTSy", + "has_more": true + } + + /models: + get: + operationId: listModels + tags: + - Models + summary: Lists the currently available models, and provides basic information about each one such as the owner and availability. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListModelsResponse" + x-oaiMeta: + name: List models + group: models + returns: A list of [model](/docs/api-reference/models/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/models \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.models.list() + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.models.list(); + + for await (const model of list) { + console.log(model); + } + } + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "model-id-0", + "object": "model", + "created": 1686935002, + "owned_by": "organization-owner" + }, + { + "id": "model-id-1", + "object": "model", + "created": 1686935002, + "owned_by": "organization-owner", + }, + { + "id": "model-id-2", + "object": "model", + "created": 1686935002, + "owned_by": "openai" + }, + ], + "object": "list" + } + /models/{model}: + get: + operationId: retrieveModel + tags: + - Models + summary: Retrieves a model instance, providing basic information about the model such as the owner and permissioning. + parameters: + - in: path + name: model + required: true + schema: + type: string + # ideally this will be an actual ID, so this will always work from browser + example: gpt-4o-mini + description: The ID of the model to use for this request + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/Model" + x-oaiMeta: + name: Retrieve model + group: models + returns: The [model](/docs/api-reference/models/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/models/VAR_model_id \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.models.retrieve("VAR_model_id") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const model = await openai.models.retrieve("VAR_model_id"); + + console.log(model); + } + + main(); + response: &retrieve_model_response | + { + "id": "VAR_model_id", + "object": "model", + "created": 1686935002, + "owned_by": "openai" + } + delete: + operationId: deleteModel + tags: + - Models + summary: Delete a fine-tuned model. You must have the Owner role in your organization to delete a model. + parameters: + - in: path + name: model + required: true + schema: + type: string + example: ft:gpt-4o-mini:acemeco:suffix:abc123 + description: The model to delete + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteModelResponse" + x-oaiMeta: + name: Delete a fine-tuned model + group: models + returns: Deletion status. + examples: + request: + curl: | + curl https://api.openai.com/v1/models/ft:gpt-4o-mini:acemeco:suffix:abc123 \ + -X DELETE \ + -H "Authorization: Bearer $OPENAI_API_KEY" + python: | + from openai import OpenAI + client = OpenAI() + + client.models.delete("ft:gpt-4o-mini:acemeco:suffix:abc123") + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const model = await openai.models.del("ft:gpt-4o-mini:acemeco:suffix:abc123"); + + console.log(model); + } + main(); + response: | + { + "id": "ft:gpt-4o-mini:acemeco:suffix:abc123", + "object": "model", + "deleted": true + } + + /moderations: + post: + operationId: createModeration + tags: + - Moderations + summary: Classifies if text is potentially harmful. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateModerationRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/CreateModerationResponse" + x-oaiMeta: + name: Create moderation + group: moderations + returns: A [moderation](/docs/api-reference/moderations/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/moderations \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -d '{ + "input": "I want to kill them." + }' + python: | + from openai import OpenAI + client = OpenAI() + + moderation = client.moderations.create(input="I want to kill them.") + print(moderation) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const moderation = await openai.moderations.create({ input: "I want to kill them." }); + + console.log(moderation); + } + main(); + response: &moderation_example | + { + "id": "modr-XXXXX", + "model": "text-moderation-005", + "results": [ + { + "flagged": true, + "categories": { + "sexual": false, + "hate": false, + "harassment": false, + "self-harm": false, + "sexual/minors": false, + "hate/threatening": false, + "violence/graphic": false, + "self-harm/intent": false, + "self-harm/instructions": false, + "harassment/threatening": true, + "violence": true, + }, + "category_scores": { + "sexual": 1.2282071e-06, + "hate": 0.010696256, + "harassment": 0.29842457, + "self-harm": 1.5236925e-08, + "sexual/minors": 5.7246268e-08, + "hate/threatening": 0.0060676364, + "violence/graphic": 4.435014e-06, + "self-harm/intent": 8.098441e-10, + "self-harm/instructions": 2.8498655e-11, + "harassment/threatening": 0.63055265, + "violence": 0.99011886, + } + } + ] + } + + /assistants: + get: + operationId: listAssistants + tags: + - Assistants + summary: Returns a list of assistants. + parameters: + - name: limit + in: query + description: &pagination_limit_param_description | + A limit on the number of objects to be returned. Limit can range between 1 and 100, and the default is 20. + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: &pagination_order_param_description | + Sort order by the `created_at` timestamp of the objects. `asc` for ascending order and `desc` for descending order. + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: &pagination_after_param_description | + A cursor for use in pagination. `after` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include after=obj_foo in order to fetch the next page of the list. + schema: + type: string + - name: before + in: query + description: &pagination_before_param_description | + A cursor for use in pagination. `before` is an object ID that defines your place in the list. For instance, if you make a list request and receive 100 objects, ending with obj_foo, your subsequent call can include before=obj_foo in order to fetch the previous page of the list. + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListAssistantsResponse" + x-oaiMeta: + name: List assistants + group: assistants + beta: true + returns: A list of [assistant](/docs/api-reference/assistants/object) objects. + examples: + request: + curl: | + curl "https://api.openai.com/v1/assistants?order=desc&limit=20" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + my_assistants = client.beta.assistants.list( + order="desc", + limit="20", + ) + print(my_assistants.data) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myAssistants = await openai.beta.assistants.list({ + order: "desc", + limit: "20", + }); + + console.log(myAssistants.data); + } + + main(); + response: &list_assistants_example | + { + "object": "list", + "data": [ + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1698982736, + "name": "Coding Tutor", + "description": null, + "model": "gpt-4o", + "instructions": "You are a helpful assistant designed to make me better at coding!", + "tools": [], + "tool_resources": {}, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + }, + { + "id": "asst_abc456", + "object": "assistant", + "created_at": 1698982718, + "name": "My Assistant", + "description": null, + "model": "gpt-4o", + "instructions": "You are a helpful assistant designed to make me better at coding!", + "tools": [], + "tool_resources": {}, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + }, + { + "id": "asst_abc789", + "object": "assistant", + "created_at": 1698982643, + "name": null, + "description": null, + "model": "gpt-4o", + "instructions": null, + "tools": [], + "tool_resources": {}, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + ], + "first_id": "asst_abc123", + "last_id": "asst_abc789", + "has_more": false + } + post: + operationId: createAssistant + tags: + - Assistants + summary: Create an assistant with a model and instructions. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateAssistantRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/AssistantObject" + x-oaiMeta: + name: Create assistant + group: assistants + beta: true + returns: An [assistant](/docs/api-reference/assistants/object) object. + examples: + - title: Code Interpreter + request: + curl: | + curl "https://api.openai.com/v1/assistants" \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + "name": "Math Tutor", + "tools": [{"type": "code_interpreter"}], + "model": "gpt-4o" + }' + + python: | + from openai import OpenAI + client = OpenAI() + + my_assistant = client.beta.assistants.create( + instructions="You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + name="Math Tutor", + tools=[{"type": "code_interpreter"}], + model="gpt-4o", + ) + print(my_assistant) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myAssistant = await openai.beta.assistants.create({ + instructions: + "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + name: "Math Tutor", + tools: [{ type: "code_interpreter" }], + model: "gpt-4o", + }); + + console.log(myAssistant); + } + + main(); + response: &create_assistants_example | + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1698984975, + "name": "Math Tutor", + "description": null, + "model": "gpt-4o", + "instructions": "You are a personal math tutor. When asked a question, write and run Python code to answer the question.", + "tools": [ + { + "type": "code_interpreter" + } + ], + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + - title: Files + request: + curl: | + curl https://api.openai.com/v1/assistants \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", + "tools": [{"type": "file_search"}], + "tool_resources": {"file_search": {"vector_store_ids": ["vs_123"]}}, + "model": "gpt-4o" + }' + python: | + from openai import OpenAI + client = OpenAI() + + my_assistant = client.beta.assistants.create( + instructions="You are an HR bot, and you have access to files to answer employee questions about company policies.", + name="HR Helper", + tools=[{"type": "file_search"}], + tool_resources={"file_search": {"vector_store_ids": ["vs_123"]}}, + model="gpt-4o" + ) + print(my_assistant) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myAssistant = await openai.beta.assistants.create({ + instructions: + "You are an HR bot, and you have access to files to answer employee questions about company policies.", + name: "HR Helper", + tools: [{ type: "file_search" }], + tool_resources: { + file_search: { + vector_store_ids: ["vs_123"] + } + }, + model: "gpt-4o" + }); + + console.log(myAssistant); + } + + main(); + response: | + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1699009403, + "name": "HR Helper", + "description": null, + "model": "gpt-4o", + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", + "tools": [ + { + "type": "file_search" + } + ], + "tool_resources": { + "file_search": { + "vector_store_ids": ["vs_123"] + } + }, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + + /assistants/{assistant_id}: + get: + operationId: getAssistant + tags: + - Assistants + summary: Retrieves an assistant. + parameters: + - in: path + name: assistant_id + required: true + schema: + type: string + description: The ID of the assistant to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/AssistantObject" + x-oaiMeta: + name: Retrieve assistant + group: assistants + beta: true + returns: The [assistant](/docs/api-reference/assistants/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/assistants/asst_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + my_assistant = client.beta.assistants.retrieve("asst_abc123") + print(my_assistant) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myAssistant = await openai.beta.assistants.retrieve( + "asst_abc123" + ); + + console.log(myAssistant); + } + + main(); + response: | + { + "id": "asst_abc123", + "object": "assistant", + "created_at": 1699009709, + "name": "HR Helper", + "description": null, + "model": "gpt-4o", + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies.", + "tools": [ + { + "type": "file_search" + } + ], + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + post: + operationId: modifyAssistant + tags: + - Assistants + summary: Modifies an assistant. + parameters: + - in: path + name: assistant_id + required: true + schema: + type: string + description: The ID of the assistant to modify. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ModifyAssistantRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/AssistantObject" + x-oaiMeta: + name: Modify assistant + group: assistants + beta: true + returns: The modified [assistant](/docs/api-reference/assistants/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/assistants/asst_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", + "tools": [{"type": "file_search"}], + "model": "gpt-4o" + }' + python: | + from openai import OpenAI + client = OpenAI() + + my_updated_assistant = client.beta.assistants.update( + "asst_abc123", + instructions="You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", + name="HR Helper", + tools=[{"type": "file_search"}], + model="gpt-4o" + ) + + print(my_updated_assistant) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myUpdatedAssistant = await openai.beta.assistants.update( + "asst_abc123", + { + instructions: + "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", + name: "HR Helper", + tools: [{ type: "file_search" }], + model: "gpt-4o" + } + ); + + console.log(myUpdatedAssistant); + } + + main(); + response: | + { + "id": "asst_123", + "object": "assistant", + "created_at": 1699009709, + "name": "HR Helper", + "description": null, + "model": "gpt-4o", + "instructions": "You are an HR bot, and you have access to files to answer employee questions about company policies. Always response with info from either of the files.", + "tools": [ + { + "type": "file_search" + } + ], + "tool_resources": { + "file_search": { + "vector_store_ids": [] + } + }, + "metadata": {}, + "top_p": 1.0, + "temperature": 1.0, + "response_format": "auto" + } + delete: + operationId: deleteAssistant + tags: + - Assistants + summary: Delete an assistant. + parameters: + - in: path + name: assistant_id + required: true + schema: + type: string + description: The ID of the assistant to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteAssistantResponse" + x-oaiMeta: + name: Delete assistant + group: assistants + beta: true + returns: Deletion status + examples: + request: + curl: | + curl https://api.openai.com/v1/assistants/asst_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -X DELETE + python: | + from openai import OpenAI + client = OpenAI() + + response = client.beta.assistants.delete("asst_abc123") + print(response) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const response = await openai.beta.assistants.del("asst_abc123"); + + console.log(response); + } + main(); + response: | + { + "id": "asst_abc123", + "object": "assistant.deleted", + "deleted": true + } + + /threads: + post: + operationId: createThread + tags: + - Assistants + summary: Create a thread. + requestBody: + content: + application/json: + schema: + $ref: "#/components/schemas/CreateThreadRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ThreadObject" + x-oaiMeta: + name: Create thread + group: threads + beta: true + returns: A [thread](/docs/api-reference/threads) object. + examples: + - title: Empty + request: + curl: | + curl https://api.openai.com/v1/threads \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '' + python: | + from openai import OpenAI + client = OpenAI() + + empty_thread = client.beta.threads.create() + print(empty_thread) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const emptyThread = await openai.beta.threads.create(); + + console.log(emptyThread); + } + + main(); + response: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1699012949, + "metadata": {}, + "tool_resources": {} + } + - title: Messages + request: + curl: | + curl https://api.openai.com/v1/threads \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "messages": [{ + "role": "user", + "content": "Hello, what is AI?" + }, { + "role": "user", + "content": "How does AI work? Explain it in simple terms." + }] + }' + python: | + from openai import OpenAI + client = OpenAI() + + message_thread = client.beta.threads.create( + messages=[ + { + "role": "user", + "content": "Hello, what is AI?" + }, + { + "role": "user", + "content": "How does AI work? Explain it in simple terms." + }, + ] + ) + + print(message_thread) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const messageThread = await openai.beta.threads.create({ + messages: [ + { + role: "user", + content: "Hello, what is AI?" + }, + { + role: "user", + content: "How does AI work? Explain it in simple terms.", + }, + ], + }); + + console.log(messageThread); + } + + main(); + response: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1699014083, + "metadata": {}, + "tool_resources": {} + } + + /threads/{thread_id}: + get: + operationId: getThread + tags: + - Assistants + summary: Retrieves a thread. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ThreadObject" + x-oaiMeta: + name: Retrieve thread + group: threads + beta: true + returns: The [thread](/docs/api-reference/threads/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + my_thread = client.beta.threads.retrieve("thread_abc123") + print(my_thread) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const myThread = await openai.beta.threads.retrieve( + "thread_abc123" + ); + + console.log(myThread); + } + + main(); + response: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1699014083, + "metadata": {}, + "tool_resources": { + "code_interpreter": { + "file_ids": [] + } + } + } + post: + operationId: modifyThread + tags: + - Assistants + summary: Modifies a thread. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to modify. Only the `metadata` can be modified. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ModifyThreadRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ThreadObject" + x-oaiMeta: + name: Modify thread + group: threads + beta: true + returns: The modified [thread](/docs/api-reference/threads/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "metadata": { + "modified": "true", + "user": "abc123" + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + my_updated_thread = client.beta.threads.update( + "thread_abc123", + metadata={ + "modified": "true", + "user": "abc123" + } + ) + print(my_updated_thread) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const updatedThread = await openai.beta.threads.update( + "thread_abc123", + { + metadata: { modified: "true", user: "abc123" }, + } + ); + + console.log(updatedThread); + } + + main(); + response: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1699014083, + "metadata": { + "modified": "true", + "user": "abc123" + }, + "tool_resources": {} + } + delete: + operationId: deleteThread + tags: + - Assistants + summary: Delete a thread. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteThreadResponse" + x-oaiMeta: + name: Delete thread + group: threads + beta: true + returns: Deletion status + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -X DELETE + python: | + from openai import OpenAI + client = OpenAI() + + response = client.beta.threads.delete("thread_abc123") + print(response) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const response = await openai.beta.threads.del("thread_abc123"); + + console.log(response); + } + main(); + response: | + { + "id": "thread_abc123", + "object": "thread.deleted", + "deleted": true + } + + /threads/{thread_id}/messages: + get: + operationId: listMessages + tags: + - Assistants + summary: Returns a list of messages for a given thread. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) the messages belong to. + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + - name: run_id + in: query + description: | + Filter messages by the run ID that generated them. + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListMessagesResponse" + x-oaiMeta: + name: List messages + group: threads + beta: true + returns: A list of [message](/docs/api-reference/messages) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/messages \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + thread_messages = client.beta.threads.messages.list("thread_abc123") + print(thread_messages.data) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const threadMessages = await openai.beta.threads.messages.list( + "thread_abc123" + ); + + console.log(threadMessages.data); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1699016383, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "How does AI work? Explain it in simple terms.", + "annotations": [] + } + } + ], + "attachments": [], + "metadata": {} + }, + { + "id": "msg_abc456", + "object": "thread.message", + "created_at": 1699016383, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "Hello, what is AI?", + "annotations": [] + } + } + ], + "attachments": [], + "metadata": {} + } + ], + "first_id": "msg_abc123", + "last_id": "msg_abc456", + "has_more": false + } + post: + operationId: createMessage + tags: + - Assistants + summary: Create a message. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) to create a message for. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateMessageRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/MessageObject" + x-oaiMeta: + name: Create message + group: threads + beta: true + returns: A [message](/docs/api-reference/messages/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/messages \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "role": "user", + "content": "How does AI work? Explain it in simple terms." + }' + python: | + from openai import OpenAI + client = OpenAI() + + thread_message = client.beta.threads.messages.create( + "thread_abc123", + role="user", + content="How does AI work? Explain it in simple terms.", + ) + print(thread_message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const threadMessages = await openai.beta.threads.messages.create( + "thread_abc123", + { role: "user", content: "How does AI work? Explain it in simple terms." } + ); + + console.log(threadMessages); + } + + main(); + response: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1713226573, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "How does AI work? Explain it in simple terms.", + "annotations": [] + } + } + ], + "attachments": [], + "metadata": {} + } + + /threads/{thread_id}/messages/{message_id}: + get: + operationId: getMessage + tags: + - Assistants + summary: Retrieve a message. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) to which this message belongs. + - in: path + name: message_id + required: true + schema: + type: string + description: The ID of the message to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/MessageObject" + x-oaiMeta: + name: Retrieve message + group: threads + beta: true + returns: The [message](/docs/api-reference/messages/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/messages/msg_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + message = client.beta.threads.messages.retrieve( + message_id="msg_abc123", + thread_id="thread_abc123", + ) + print(message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const message = await openai.beta.threads.messages.retrieve( + "thread_abc123", + "msg_abc123" + ); + + console.log(message); + } + + main(); + response: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1699017614, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "How does AI work? Explain it in simple terms.", + "annotations": [] + } + } + ], + "attachments": [], + "metadata": {} + } + post: + operationId: modifyMessage + tags: + - Assistants + summary: Modifies a message. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to which this message belongs. + - in: path + name: message_id + required: true + schema: + type: string + description: The ID of the message to modify. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ModifyMessageRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/MessageObject" + x-oaiMeta: + name: Modify message + group: threads + beta: true + returns: The modified [message](/docs/api-reference/messages/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/messages/msg_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "metadata": { + "modified": "true", + "user": "abc123" + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + message = client.beta.threads.messages.update( + message_id="msg_abc12", + thread_id="thread_abc123", + metadata={ + "modified": "true", + "user": "abc123", + }, + ) + print(message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const message = await openai.beta.threads.messages.update( + "thread_abc123", + "msg_abc123", + { + metadata: { + modified: "true", + user: "abc123", + }, + } + }' + response: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1699017614, + "assistant_id": null, + "thread_id": "thread_abc123", + "run_id": null, + "role": "user", + "content": [ + { + "type": "text", + "text": { + "value": "How does AI work? Explain it in simple terms.", + "annotations": [] + } + } + ], + "file_ids": [], + "metadata": { + "modified": "true", + "user": "abc123" + } + } + delete: + operationId: deleteMessage + tags: + - Assistants + summary: Deletes a message. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to which this message belongs. + - in: path + name: message_id + required: true + schema: + type: string + description: The ID of the message to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteMessageResponse" + x-oaiMeta: + name: Delete message + group: threads + beta: true + returns: Deletion status + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/threads/thread_abc123/messages/msg_abc123 \ + -H "Content-Type: application/json" \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + deleted_message = client.beta.threads.messages.delete( + message_id="msg_abc12", + thread_id="thread_abc123", + ) + print(deleted_message) + node.js: |- + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const deletedMessage = await openai.beta.threads.messages.del( + "thread_abc123", + "msg_abc123" + ); + + console.log(deletedMessage); + } + response: | + { + "id": "msg_abc123", + "object": "thread.message.deleted", + "deleted": true + } + + /threads/runs: + post: + operationId: createThreadAndRun + tags: + - Assistants + summary: Create a thread and run it in one request. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateThreadAndRunRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Create thread and run + group: threads + beta: true + returns: A [run](/docs/api-reference/runs/object) object. + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/threads/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_abc123", + "thread": { + "messages": [ + {"role": "user", "content": "Explain deep learning to a 5 year old."} + ] + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.create_and_run( + assistant_id="asst_abc123", + thread={ + "messages": [ + {"role": "user", "content": "Explain deep learning to a 5 year old."} + ] + } + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.createAndRun({ + assistant_id: "asst_abc123", + thread: { + messages: [ + { role: "user", content: "Explain deep learning to a 5 year old." }, + ], + }, + }); + + console.log(run); + } + + main(); + response: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699076792, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "queued", + "started_at": null, + "expires_at": 1699077392, + "cancelled_at": null, + "failed_at": null, + "completed_at": null, + "required_action": null, + "last_error": null, + "model": "gpt-4o", + "instructions": "You are a helpful assistant.", + "tools": [], + "tool_resources": {}, + "metadata": {}, + "temperature": 1.0, + "top_p": 1.0, + "max_completion_tokens": null, + "max_prompt_tokens": null, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "incomplete_details": null, + "usage": null, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/threads/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_123", + "thread": { + "messages": [ + {"role": "user", "content": "Hello"} + ] + }, + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + stream = client.beta.threads.create_and_run( + assistant_id="asst_123", + thread={ + "messages": [ + {"role": "user", "content": "Hello"} + ] + }, + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const stream = await openai.beta.threads.createAndRun({ + assistant_id: "asst_123", + thread: { + messages: [ + { role: "user", content: "Hello" }, + ], + }, + stream: true + }); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.created + data: {"id":"thread_123","object":"thread","created_at":1710348075,"metadata":{}} + + event: thread.run.created + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"tool_resources":{},"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: thread.run.step.created + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.message.created + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[], "metadata":{}} + + event: thread.message.in_progress + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[], "metadata":{}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"Hello","annotations":[]}}]}} + + ... + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" today"}}]}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"?"}}]}} + + event: thread.message.completed + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1710348077,"role":"assistant","content":[{"type":"text","text":{"value":"Hello! How can I assist you today?","annotations":[]}}], "metadata":{}} + + event: thread.run.step.completed + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710348077,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} + + event: thread.run.completed + {"id":"run_123","object":"thread.run","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1713226836,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1713226837,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":345,"completion_tokens":11,"total_tokens":356},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true} + + event: done + data: [DONE] + + - title: Streaming with Functions + request: + curl: | + curl https://api.openai.com/v1/threads/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_abc123", + "thread": { + "messages": [ + {"role": "user", "content": "What is the weather like in San Francisco?"} + ] + }, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ] + + stream = client.beta.threads.create_and_run( + thread={ + "messages": [ + {"role": "user", "content": "What is the weather like in San Francisco?"} + ] + }, + assistant_id="asst_abc123", + tools=tools, + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + const tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ]; + + async function main() { + const stream = await openai.beta.threads.createAndRun({ + assistant_id: "asst_123", + thread: { + messages: [ + { role: "user", content: "What is the weather like in San Francisco?" }, + ], + }, + tools: tools, + stream: true + }); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.created + data: {"id":"thread_123","object":"thread","created_at":1710351818,"metadata":{}} + + event: thread.run.created + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710351818,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.step.created + data: {"id":"step_001","object":"thread.run.step","created_at":1710351819,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"tool_calls","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710352418,"failed_at":null,"last_error":null,"step_details":{"type":"tool_calls","tool_calls":[]},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_001","object":"thread.run.step","created_at":1710351819,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"tool_calls","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710352418,"failed_at":null,"last_error":null,"step_details":{"type":"tool_calls","tool_calls":[]},"usage":null} + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"id":"call_XXNp8YGaFrjrSjgqxtC8JJ1B","type":"function","function":{"name":"get_current_weather","arguments":"","output":null}}]}}} + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"{\""}}]}}} + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"location"}}]}}} + + ... + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"ahrenheit"}}]}}} + + event: thread.run.step.delta + data: {"id":"step_001","object":"thread.run.step.delta","delta":{"step_details":{"type":"tool_calls","tool_calls":[{"index":0,"type":"function","function":{"arguments":"\"}"}}]}}} + + event: thread.run.requires_action + data: {"id":"run_123","object":"thread.run","created_at":1710351818,"assistant_id":"asst_123","thread_id":"thread_123","status":"requires_action","started_at":1710351818,"expires_at":1710352418,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":{"type":"submit_tool_outputs","submit_tool_outputs":{"tool_calls":[{"id":"call_XXNp8YGaFrjrSjgqxtC8JJ1B","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\":\"San Francisco, CA\",\"unit\":\"fahrenheit\"}"}}]}},"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":345,"completion_tokens":11,"total_tokens":356},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: done + data: [DONE] + + /threads/{thread_id}/runs: + get: + operationId: listRuns + tags: + - Assistants + summary: Returns a list of runs belonging to a thread. + parameters: + - name: thread_id + in: path + required: true + schema: + type: string + description: The ID of the thread the run belongs to. + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListRunsResponse" + x-oaiMeta: + name: List runs + group: threads + beta: true + returns: A list of [run](/docs/api-reference/runs/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + runs = client.beta.threads.runs.list( + "thread_abc123" + ) + + print(runs) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const runs = await openai.beta.threads.runs.list( + "thread_abc123" + ); + + console.log(runs); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699075072, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699075072, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699075073, + "last_error": null, + "model": "gpt-4o", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "tool_resources": { + "code_interpreter": { + "file_ids": [ + "file-abc123", + "file-abc456" + ] + } + }, + "metadata": {}, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + }, + { + "id": "run_abc456", + "object": "thread.run", + "created_at": 1699063290, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699063290, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699063291, + "last_error": null, + "model": "gpt-4o", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "tool_resources": { + "code_interpreter": { + "file_ids": [ + "file-abc123", + "file-abc456" + ] + } + }, + "metadata": {}, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + ], + "first_id": "run_abc123", + "last_id": "run_abc456", + "has_more": false + } + post: + operationId: createRun + tags: + - Assistants + summary: Create a run. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to run. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateRunRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Create run + group: threads + beta: true + returns: A [run](/docs/api-reference/runs/object) object. + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_abc123" + }' + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.create( + thread_id="thread_abc123", + assistant_id="asst_abc123" + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.create( + "thread_abc123", + { assistant_id: "asst_abc123" } + ); + + console.log(run); + } + + main(); + response: &run_object_example | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699063290, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "queued", + "started_at": 1699063290, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699063291, + "last_error": null, + "model": "gpt-4o", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "metadata": {}, + "usage": null, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/threads/thread_123/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_123", + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + stream = client.beta.threads.runs.create( + thread_id="thread_123", + assistant_id="asst_123", + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const stream = await openai.beta.threads.runs.create( + "thread_123", + { assistant_id: "asst_123", stream: true } + ); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.run.created + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710330641,"expires_at":1710331240,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.step.created + data: {"id":"step_001","object":"thread.run.step","created_at":1710330641,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710331240,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_001","object":"thread.run.step","created_at":1710330641,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710331240,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.message.created + data: {"id":"msg_001","object":"thread.message","created_at":1710330641,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.in_progress + data: {"id":"msg_001","object":"thread.message","created_at":1710330641,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"Hello","annotations":[]}}]}} + + ... + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" today"}}]}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"?"}}]}} + + event: thread.message.completed + data: {"id":"msg_001","object":"thread.message","created_at":1710330641,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1710330642,"role":"assistant","content":[{"type":"text","text":{"value":"Hello! How can I assist you today?","annotations":[]}}],"metadata":{}} + + event: thread.run.step.completed + data: {"id":"step_001","object":"thread.run.step","created_at":1710330641,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710330642,"expires_at":1710331240,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} + + event: thread.run.completed + data: {"id":"run_123","object":"thread.run","created_at":1710330640,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710330641,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710330642,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: done + data: [DONE] + + - title: Streaming with Functions + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "assistant_id": "asst_abc123", + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ] + + stream = client.beta.threads.runs.create( + thread_id="thread_abc123", + assistant_id="asst_abc123", + tools=tools, + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + const tools = [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA", + }, + "unit": {"type": "string", "enum": ["celsius", "fahrenheit"]}, + }, + "required": ["location"], + }, + } + } + ]; + + async function main() { + const stream = await openai.beta.threads.runs.create( + "thread_abc123", + { + assistant_id: "asst_abc123", + tools: tools, + stream: true + } + ); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.run.created + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":null,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710348075,"expires_at":1710348675,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.step.created + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":null} + + event: thread.message.created + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.in_progress + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"Hello","annotations":[]}}]}} + + ... + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" today"}}]}} + + event: thread.message.delta + data: {"id":"msg_001","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"?"}}]}} + + event: thread.message.completed + data: {"id":"msg_001","object":"thread.message","created_at":1710348076,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1710348077,"role":"assistant","content":[{"type":"text","text":{"value":"Hello! How can I assist you today?","annotations":[]}}],"metadata":{}} + + event: thread.run.step.completed + data: {"id":"step_001","object":"thread.run.step","created_at":1710348076,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710348077,"expires_at":1710348675,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_001"}},"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31}} + + event: thread.run.completed + data: {"id":"run_123","object":"thread.run","created_at":1710348075,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710348075,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710348077,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: done + data: [DONE] + + /threads/{thread_id}/runs/{run_id}: + get: + operationId: getRun + tags: + - Assistants + summary: Retrieves a run. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) that was run. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Retrieve run + group: threads + beta: true + returns: The [run](/docs/api-reference/runs/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.retrieve( + thread_id="thread_abc123", + run_id="run_abc123" + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.retrieve( + "thread_abc123", + "run_abc123" + ); + + console.log(run); + } + + main(); + response: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699075072, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699075072, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699075073, + "last_error": null, + "model": "gpt-4o", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "metadata": {}, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + post: + operationId: modifyRun + tags: + - Assistants + summary: Modifies a run. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) that was run. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run to modify. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ModifyRunRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Modify run + group: threads + beta: true + returns: The modified [run](/docs/api-reference/runs/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "metadata": { + "user_id": "user_abc123" + } + }' + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.update( + thread_id="thread_abc123", + run_id="run_abc123", + metadata={"user_id": "user_abc123"}, + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.update( + "thread_abc123", + "run_abc123", + { + metadata: { + user_id: "user_abc123", + }, + } + ); + + console.log(run); + } + + main(); + response: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699075072, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699075072, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699075073, + "last_error": null, + "model": "gpt-4o", + "instructions": null, + "incomplete_details": null, + "tools": [ + { + "type": "code_interpreter" + } + ], + "tool_resources": { + "code_interpreter": { + "file_ids": [ + "file-abc123", + "file-abc456" + ] + } + }, + "metadata": { + "user_id": "user_abc123" + }, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + + /threads/{thread_id}/runs/{run_id}/submit_tool_outputs: + post: + operationId: submitToolOuputsToRun + tags: + - Assistants + summary: | + When a run has the `status: "requires_action"` and `required_action.type` is `submit_tool_outputs`, this endpoint can be used to submit the outputs from the tool calls once they're all completed. All outputs must be submitted in a single request. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the [thread](/docs/api-reference/threads) to which this run belongs. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run that requires the tool output submission. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/SubmitToolOutputsRunRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Submit tool outputs to run + group: threads + beta: true + returns: The modified [run](/docs/api-reference/runs/object) object matching the specified ID. + examples: + - title: Default + request: + curl: | + curl https://api.openai.com/v1/threads/thread_123/runs/run_123/submit_tool_outputs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "tool_outputs": [ + { + "tool_call_id": "call_001", + "output": "70 degrees and sunny." + } + ] + }' + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.submit_tool_outputs( + thread_id="thread_123", + run_id="run_123", + tool_outputs=[ + { + "tool_call_id": "call_001", + "output": "70 degrees and sunny." + } + ] + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.submitToolOutputs( + "thread_123", + "run_123", + { + tool_outputs: [ + { + tool_call_id: "call_001", + output: "70 degrees and sunny.", + }, + ], + } + ); + + console.log(run); + } + + main(); + response: | + { + "id": "run_123", + "object": "thread.run", + "created_at": 1699075592, + "assistant_id": "asst_123", + "thread_id": "thread_123", + "status": "queued", + "started_at": 1699075592, + "expires_at": 1699076192, + "cancelled_at": null, + "failed_at": null, + "completed_at": null, + "last_error": null, + "model": "gpt-4o", + "instructions": null, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather in a given location", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and state, e.g. San Francisco, CA" + }, + "unit": { + "type": "string", + "enum": ["celsius", "fahrenheit"] + } + }, + "required": ["location"] + } + } + } + ], + "metadata": {}, + "usage": null, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + + - title: Streaming + request: + curl: | + curl https://api.openai.com/v1/threads/thread_123/runs/run_123/submit_tool_outputs \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "tool_outputs": [ + { + "tool_call_id": "call_001", + "output": "70 degrees and sunny." + } + ], + "stream": true + }' + python: | + from openai import OpenAI + client = OpenAI() + + stream = client.beta.threads.runs.submit_tool_outputs( + thread_id="thread_123", + run_id="run_123", + tool_outputs=[ + { + "tool_call_id": "call_001", + "output": "70 degrees and sunny." + } + ], + stream=True + ) + + for event in stream: + print(event) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const stream = await openai.beta.threads.runs.submitToolOutputs( + "thread_123", + "run_123", + { + tool_outputs: [ + { + tool_call_id: "call_001", + output: "70 degrees and sunny.", + }, + ], + } + ); + + for await (const event of stream) { + console.log(event); + } + } + + main(); + response: | + event: thread.run.step.completed + data: {"id":"step_001","object":"thread.run.step","created_at":1710352449,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"tool_calls","status":"completed","cancelled_at":null,"completed_at":1710352475,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"tool_calls","tool_calls":[{"id":"call_iWr0kQ2EaYMaxNdl0v3KYkx7","type":"function","function":{"name":"get_current_weather","arguments":"{\"location\":\"San Francisco, CA\",\"unit\":\"fahrenheit\"}","output":"70 degrees and sunny."}}]},"usage":{"prompt_tokens":291,"completion_tokens":24,"total_tokens":315}} + + event: thread.run.queued + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"queued","started_at":1710352448,"expires_at":1710353047,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.in_progress + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"in_progress","started_at":1710352475,"expires_at":1710353047,"cancelled_at":null,"failed_at":null,"completed_at":null,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":null,"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: thread.run.step.created + data: {"id":"step_002","object":"thread.run.step","created_at":1710352476,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_002"}},"usage":null} + + event: thread.run.step.in_progress + data: {"id":"step_002","object":"thread.run.step","created_at":1710352476,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"in_progress","cancelled_at":null,"completed_at":null,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_002"}},"usage":null} + + event: thread.message.created + data: {"id":"msg_002","object":"thread.message","created_at":1710352476,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.in_progress + data: {"id":"msg_002","object":"thread.message","created_at":1710352476,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"in_progress","incomplete_details":null,"incomplete_at":null,"completed_at":null,"role":"assistant","content":[],"metadata":{}} + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"The","annotations":[]}}]}} + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" current"}}]}} + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" weather"}}]}} + + ... + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":" sunny"}}]}} + + event: thread.message.delta + data: {"id":"msg_002","object":"thread.message.delta","delta":{"content":[{"index":0,"type":"text","text":{"value":"."}}]}} + + event: thread.message.completed + data: {"id":"msg_002","object":"thread.message","created_at":1710352476,"assistant_id":"asst_123","thread_id":"thread_123","run_id":"run_123","status":"completed","incomplete_details":null,"incomplete_at":null,"completed_at":1710352477,"role":"assistant","content":[{"type":"text","text":{"value":"The current weather in San Francisco, CA is 70 degrees Fahrenheit and sunny.","annotations":[]}}],"metadata":{}} + + event: thread.run.step.completed + data: {"id":"step_002","object":"thread.run.step","created_at":1710352476,"run_id":"run_123","assistant_id":"asst_123","thread_id":"thread_123","type":"message_creation","status":"completed","cancelled_at":null,"completed_at":1710352477,"expires_at":1710353047,"failed_at":null,"last_error":null,"step_details":{"type":"message_creation","message_creation":{"message_id":"msg_002"}},"usage":{"prompt_tokens":329,"completion_tokens":18,"total_tokens":347}} + + event: thread.run.completed + data: {"id":"run_123","object":"thread.run","created_at":1710352447,"assistant_id":"asst_123","thread_id":"thread_123","status":"completed","started_at":1710352475,"expires_at":null,"cancelled_at":null,"failed_at":null,"completed_at":1710352477,"required_action":null,"last_error":null,"model":"gpt-4o","instructions":null,"tools":[{"type":"function","function":{"name":"get_current_weather","description":"Get the current weather in a given location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The city and state, e.g. San Francisco, CA"},"unit":{"type":"string","enum":["celsius","fahrenheit"]}},"required":["location"]}}}],"metadata":{},"temperature":1.0,"top_p":1.0,"max_completion_tokens":null,"max_prompt_tokens":null,"truncation_strategy":{"type":"auto","last_messages":null},"incomplete_details":null,"usage":{"prompt_tokens":20,"completion_tokens":11,"total_tokens":31},"response_format":"auto","tool_choice":"auto","parallel_tool_calls":true}} + + event: done + data: [DONE] + + /threads/{thread_id}/runs/{run_id}/cancel: + post: + operationId: cancelRun + tags: + - Assistants + summary: Cancels a run that is `in_progress`. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to which this run belongs. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run to cancel. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunObject" + x-oaiMeta: + name: Cancel a run + group: threads + beta: true + returns: The modified [run](/docs/api-reference/runs/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123/cancel \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "OpenAI-Beta: assistants=v2" \ + -X POST + python: | + from openai import OpenAI + client = OpenAI() + + run = client.beta.threads.runs.cancel( + thread_id="thread_abc123", + run_id="run_abc123" + ) + + print(run) + node.js: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const run = await openai.beta.threads.runs.cancel( + "thread_abc123", + "run_abc123" + ); + + console.log(run); + } + + main(); + response: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1699076126, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "cancelling", + "started_at": 1699076126, + "expires_at": 1699076726, + "cancelled_at": null, + "failed_at": null, + "completed_at": null, + "last_error": null, + "model": "gpt-4o", + "instructions": "You summarize books.", + "tools": [ + { + "type": "file_search" + } + ], + "tool_resources": { + "file_search": { + "vector_store_ids": ["vs_123"] + } + }, + "metadata": {}, + "usage": null, + "temperature": 1.0, + "top_p": 1.0, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + + /threads/{thread_id}/runs/{run_id}/steps: + get: + operationId: listRunSteps + tags: + - Assistants + summary: Returns a list of run steps belonging to a run. + parameters: + - name: thread_id + in: path + required: true + schema: + type: string + description: The ID of the thread the run and run steps belong to. + - name: run_id + in: path + required: true + schema: + type: string + description: The ID of the run the run steps belong to. + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListRunStepsResponse" + x-oaiMeta: + name: List run steps + group: threads + beta: true + returns: A list of [run step](/docs/api-reference/runs/step-object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123/steps \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + run_steps = client.beta.threads.runs.steps.list( + thread_id="thread_abc123", + run_id="run_abc123" + ) + + print(run_steps) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const runStep = await openai.beta.threads.runs.steps.list( + "thread_abc123", + "run_abc123" + ); + console.log(runStep); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "step_abc123", + "object": "thread.run.step", + "created_at": 1699063291, + "run_id": "run_abc123", + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "type": "message_creation", + "status": "completed", + "cancelled_at": null, + "completed_at": 1699063291, + "expired_at": null, + "failed_at": null, + "last_error": null, + "step_details": { + "type": "message_creation", + "message_creation": { + "message_id": "msg_abc123" + } + }, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + } + } + ], + "first_id": "step_abc123", + "last_id": "step_abc456", + "has_more": false + } + + /threads/{thread_id}/runs/{run_id}/steps/{step_id}: + get: + operationId: getRunStep + tags: + - Assistants + summary: Retrieves a run step. + parameters: + - in: path + name: thread_id + required: true + schema: + type: string + description: The ID of the thread to which the run and run step belongs. + - in: path + name: run_id + required: true + schema: + type: string + description: The ID of the run to which the run step belongs. + - in: path + name: step_id + required: true + schema: + type: string + description: The ID of the run step to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/RunStepObject" + x-oaiMeta: + name: Retrieve run step + group: threads + beta: true + returns: The [run step](/docs/api-reference/runs/step-object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/threads/thread_abc123/runs/run_abc123/steps/step_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + run_step = client.beta.threads.runs.steps.retrieve( + thread_id="thread_abc123", + run_id="run_abc123", + step_id="step_abc123" + ) + + print(run_step) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const runStep = await openai.beta.threads.runs.steps.retrieve( + "thread_abc123", + "run_abc123", + "step_abc123" + ); + console.log(runStep); + } + + main(); + response: &run_step_object_example | + { + "id": "step_abc123", + "object": "thread.run.step", + "created_at": 1699063291, + "run_id": "run_abc123", + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "type": "message_creation", + "status": "completed", + "cancelled_at": null, + "completed_at": 1699063291, + "expired_at": null, + "failed_at": null, + "last_error": null, + "step_details": { + "type": "message_creation", + "message_creation": { + "message_id": "msg_abc123" + } + }, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + } + } + + /vector_stores: + get: + operationId: listVectorStores + tags: + - Vector Stores + summary: Returns a list of vector stores. + parameters: + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListVectorStoresResponse" + x-oaiMeta: + name: List vector stores + group: vector_stores + beta: true + returns: A list of [vector store](/docs/api-reference/vector-stores/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_stores = client.beta.vector_stores.list() + print(vector_stores) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStores = await openai.beta.vectorStores.list(); + console.log(vectorStores); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "vs_abc123", + "object": "vector_store", + "created_at": 1699061776, + "name": "Support FAQ", + "bytes": 139920, + "file_counts": { + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 3 + } + }, + { + "id": "vs_abc456", + "object": "vector_store", + "created_at": 1699061776, + "name": "Support FAQ v2", + "bytes": 139920, + "file_counts": { + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 3 + } + } + ], + "first_id": "vs_abc123", + "last_id": "vs_abc456", + "has_more": false + } + post: + operationId: createVectorStore + tags: + - Vector Stores + summary: Create a vector store. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateVectorStoreRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreObject" + x-oaiMeta: + name: Create vector store + group: vector_stores + beta: true + returns: A [vector store](/docs/api-reference/vector-stores/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + -d '{ + "name": "Support FAQ" + }' + python: | + from openai import OpenAI + client = OpenAI() + + vector_store = client.beta.vector_stores.create( + name="Support FAQ" + ) + print(vector_store) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStore = await openai.beta.vectorStores.create({ + name: "Support FAQ" + }); + console.log(vectorStore); + } + + main(); + response: | + { + "id": "vs_abc123", + "object": "vector_store", + "created_at": 1699061776, + "name": "Support FAQ", + "bytes": 139920, + "file_counts": { + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 3 + } + } + + /vector_stores/{vector_store_id}: + get: + operationId: getVectorStore + tags: + - Vector Stores + summary: Retrieves a vector store. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store to retrieve. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreObject" + x-oaiMeta: + name: Retrieve vector store + group: vector_stores + beta: true + returns: The [vector store](/docs/api-reference/vector-stores/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store = client.beta.vector_stores.retrieve( + vector_store_id="vs_abc123" + ) + print(vector_store) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStore = await openai.beta.vectorStores.retrieve( + "vs_abc123" + ); + console.log(vectorStore); + } + + main(); + response: | + { + "id": "vs_abc123", + "object": "vector_store", + "created_at": 1699061776 + } + post: + operationId: modifyVectorStore + tags: + - Vector Stores + summary: Modifies a vector store. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store to modify. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UpdateVectorStoreRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreObject" + x-oaiMeta: + name: Modify vector store + group: vector_stores + beta: true + returns: The modified [vector store](/docs/api-reference/vector-stores/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + -d '{ + "name": "Support FAQ" + }' + python: | + from openai import OpenAI + client = OpenAI() + + vector_store = client.beta.vector_stores.update( + vector_store_id="vs_abc123", + name="Support FAQ" + ) + print(vector_store) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStore = await openai.beta.vectorStores.update( + "vs_abc123", + { + name: "Support FAQ" + } + ); + console.log(vectorStore); + } + + main(); + response: | + { + "id": "vs_abc123", + "object": "vector_store", + "created_at": 1699061776, + "name": "Support FAQ", + "bytes": 139920, + "file_counts": { + "in_progress": 0, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 3 + } + } + + delete: + operationId: deleteVectorStore + tags: + - Vector Stores + summary: Delete a vector store. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteVectorStoreResponse" + x-oaiMeta: + name: Delete vector store + group: vector_stores + beta: true + returns: Deletion status + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -X DELETE + python: | + from openai import OpenAI + client = OpenAI() + + deleted_vector_store = client.beta.vector_stores.delete( + vector_store_id="vs_abc123" + ) + print(deleted_vector_store) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const deletedVectorStore = await openai.beta.vectorStores.del( + "vs_abc123" + ); + console.log(deletedVectorStore); + } + + main(); + response: | + { + id: "vs_abc123", + object: "vector_store.deleted", + deleted: true + } + + /vector_stores/{vector_store_id}/files: + get: + operationId: listVectorStoreFiles + tags: + - Vector Stores + summary: Returns a list of vector store files. + parameters: + - name: vector_store_id + in: path + description: The ID of the vector store that the files belong to. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + - name: filter + in: query + description: "Filter by file status. One of `in_progress`, `completed`, `failed`, `cancelled`." + schema: + type: string + enum: ["in_progress", "completed", "failed", "cancelled"] + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListVectorStoreFilesResponse" + x-oaiMeta: + name: List vector store files + group: vector_stores + beta: true + returns: A list of [vector store file](/docs/api-reference/vector-stores-files/file-object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_files = client.beta.vector_stores.files.list( + vector_store_id="vs_abc123" + ) + print(vector_store_files) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStoreFiles = await openai.beta.vectorStores.files.list( + "vs_abc123" + ); + console.log(vectorStoreFiles); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "file-abc123", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abc123" + }, + { + "id": "file-abc456", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abc123" + } + ], + "first_id": "file-abc123", + "last_id": "file-abc456", + "has_more": false + } + post: + operationId: createVectorStoreFile + tags: + - Vector Stores + summary: Create a vector store file by attaching a [File](/docs/api-reference/files) to a [vector store](/docs/api-reference/vector-stores/object). + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + example: vs_abc123 + description: | + The ID of the vector store for which to create a File. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateVectorStoreFileRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileObject" + x-oaiMeta: + name: Create vector store file + group: vector_stores + beta: true + returns: A [vector store file](/docs/api-reference/vector-stores-files/file-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "file_id": "file-abc123" + }' + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_file = client.beta.vector_stores.files.create( + vector_store_id="vs_abc123", + file_id="file-abc123" + ) + print(vector_store_file) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const myVectorStoreFile = await openai.beta.vectorStores.files.create( + "vs_abc123", + { + file_id: "file-abc123" + } + ); + console.log(myVectorStoreFile); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "vector_store.file", + "created_at": 1699061776, + "usage_bytes": 1234, + "vector_store_id": "vs_abcd", + "status": "completed", + "last_error": null + } + + /vector_stores/{vector_store_id}/files/{file_id}: + get: + operationId: getVectorStoreFile + tags: + - Vector Stores + summary: Retrieves a vector store file. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + example: vs_abc123 + description: The ID of the vector store that the file belongs to. + - in: path + name: file_id + required: true + schema: + type: string + example: file-abc123 + description: The ID of the file being retrieved. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileObject" + x-oaiMeta: + name: Retrieve vector store file + group: vector_stores + beta: true + returns: The [vector store file](/docs/api-reference/vector-stores-files/file-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files/file-abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_file = client.beta.vector_stores.files.retrieve( + vector_store_id="vs_abc123", + file_id="file-abc123" + ) + print(vector_store_file) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStoreFile = await openai.beta.vectorStores.files.retrieve( + "vs_abc123", + "file-abc123" + ); + console.log(vectorStoreFile); + } + + main(); + response: | + { + "id": "file-abc123", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abcd", + "status": "completed", + "last_error": null + } + delete: + operationId: deleteVectorStoreFile + tags: + - Vector Stores + summary: Delete a vector store file. This will remove the file from the vector store but the file itself will not be deleted. To delete the file, use the [delete file](/docs/api-reference/files/delete) endpoint. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store that the file belongs to. + - in: path + name: file_id + required: true + schema: + type: string + description: The ID of the file to delete. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/DeleteVectorStoreFileResponse" + x-oaiMeta: + name: Delete vector store file + group: vector_stores + beta: true + returns: Deletion status + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files/file-abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -X DELETE + python: | + from openai import OpenAI + client = OpenAI() + + deleted_vector_store_file = client.beta.vector_stores.files.delete( + vector_store_id="vs_abc123", + file_id="file-abc123" + ) + print(deleted_vector_store_file) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const deletedVectorStoreFile = await openai.beta.vectorStores.files.del( + "vs_abc123", + "file-abc123" + ); + console.log(deletedVectorStoreFile); + } + + main(); + response: | + { + id: "file-abc123", + object: "vector_store.file.deleted", + deleted: true + } + + /vector_stores/{vector_store_id}/file_batches: + post: + operationId: createVectorStoreFileBatch + tags: + - Vector Stores + summary: Create a vector store file batch. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + example: vs_abc123 + description: | + The ID of the vector store for which to create a File Batch. + requestBody: + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/CreateVectorStoreFileBatchRequest" + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileBatchObject" + x-oaiMeta: + name: Create vector store file batch + group: vector_stores + beta: true + returns: A [vector store file batch](/docs/api-reference/vector-stores-file-batches/batch-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/file_batches \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json \ + -H "OpenAI-Beta: assistants=v2" \ + -d '{ + "file_ids": ["file-abc123", "file-abc456"] + }' + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_file_batch = client.beta.vector_stores.file_batches.create( + vector_store_id="vs_abc123", + file_ids=["file-abc123", "file-abc456"] + ) + print(vector_store_file_batch) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const myVectorStoreFileBatch = await openai.beta.vectorStores.fileBatches.create( + "vs_abc123", + { + file_ids: ["file-abc123", "file-abc456"] + } + ); + console.log(myVectorStoreFileBatch); + } + + main(); + response: | + { + "id": "vsfb_abc123", + "object": "vector_store.file_batch", + "created_at": 1699061776, + "vector_store_id": "vs_abc123", + "status": "in_progress", + "file_counts": { + "in_progress": 1, + "completed": 1, + "failed": 0, + "cancelled": 0, + "total": 0, + } + } + + /vector_stores/{vector_store_id}/file_batches/{batch_id}: + get: + operationId: getVectorStoreFileBatch + tags: + - Vector Stores + summary: Retrieves a vector store file batch. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + example: vs_abc123 + description: The ID of the vector store that the file batch belongs to. + - in: path + name: batch_id + required: true + schema: + type: string + example: vsfb_abc123 + description: The ID of the file batch being retrieved. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileBatchObject" + x-oaiMeta: + name: Retrieve vector store file batch + group: vector_stores + beta: true + returns: The [vector store file batch](/docs/api-reference/vector-stores-file-batches/batch-object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files_batches/vsfb_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_file_batch = client.beta.vector_stores.file_batches.retrieve( + vector_store_id="vs_abc123", + batch_id="vsfb_abc123" + ) + print(vector_store_file_batch) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStoreFileBatch = await openai.beta.vectorStores.fileBatches.retrieve( + "vs_abc123", + "vsfb_abc123" + ); + console.log(vectorStoreFileBatch); + } + + main(); + response: | + { + "id": "vsfb_abc123", + "object": "vector_store.file_batch", + "created_at": 1699061776, + "vector_store_id": "vs_abc123", + "status": "in_progress", + "file_counts": { + "in_progress": 1, + "completed": 1, + "failed": 0, + "cancelled": 0, + "total": 0, + } + } + + /vector_stores/{vector_store_id}/file_batches/{batch_id}/cancel: + post: + operationId: cancelVectorStoreFileBatch + tags: + - Vector Stores + summary: Cancel a vector store file batch. This attempts to cancel the processing of files in this batch as soon as possible. + parameters: + - in: path + name: vector_store_id + required: true + schema: + type: string + description: The ID of the vector store that the file batch belongs to. + - in: path + name: batch_id + required: true + schema: + type: string + description: The ID of the file batch to cancel. + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/VectorStoreFileBatchObject" + x-oaiMeta: + name: Cancel vector store file batch + group: vector_stores + beta: true + returns: The modified vector store file batch object. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files_batches/vsfb_abc123/cancel \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" \ + -X POST + python: | + from openai import OpenAI + client = OpenAI() + + deleted_vector_store_file_batch = client.beta.vector_stores.file_batches.cancel( + vector_store_id="vs_abc123", + file_batch_id="vsfb_abc123" + ) + print(deleted_vector_store_file_batch) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const deletedVectorStoreFileBatch = await openai.vector_stores.fileBatches.cancel( + "vs_abc123", + "vsfb_abc123" + ); + console.log(deletedVectorStoreFileBatch); + } + + main(); + response: | + { + "id": "vsfb_abc123", + "object": "vector_store.file_batch", + "created_at": 1699061776, + "vector_store_id": "vs_abc123", + "status": "cancelling", + "file_counts": { + "in_progress": 12, + "completed": 3, + "failed": 0, + "cancelled": 0, + "total": 15, + } + } + + /vector_stores/{vector_store_id}/file_batches/{batch_id}/files: + get: + operationId: listFilesInVectorStoreBatch + tags: + - Vector Stores + summary: Returns a list of vector store files in a batch. + parameters: + - name: vector_store_id + in: path + description: The ID of the vector store that the files belong to. + required: true + schema: + type: string + - name: batch_id + in: path + description: The ID of the file batch that the files belong to. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: order + in: query + description: *pagination_order_param_description + schema: + type: string + default: desc + enum: ["asc", "desc"] + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + - name: filter + in: query + description: "Filter by file status. One of `in_progress`, `completed`, `failed`, `cancelled`." + schema: + type: string + enum: ["in_progress", "completed", "failed", "cancelled"] + responses: + "200": + description: OK + content: + application/json: + schema: + $ref: "#/components/schemas/ListVectorStoreFilesResponse" + x-oaiMeta: + name: List vector store files in a batch + group: vector_stores + beta: true + returns: A list of [vector store file](/docs/api-reference/vector-stores-files/file-object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/vector_stores/vs_abc123/files_batches/vsfb_abc123/files \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -H "OpenAI-Beta: assistants=v2" + python: | + from openai import OpenAI + client = OpenAI() + + vector_store_files = client.beta.vector_stores.file_batches.list_files( + vector_store_id="vs_abc123", + batch_id="vsfb_abc123" + ) + print(vector_store_files) + node.js: | + import OpenAI from "openai"; + const openai = new OpenAI(); + + async function main() { + const vectorStoreFiles = await openai.beta.vectorStores.fileBatches.listFiles( + "vs_abc123", + "vsfb_abc123" + ); + console.log(vectorStoreFiles); + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "file-abc123", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abc123" + }, + { + "id": "file-abc456", + "object": "vector_store.file", + "created_at": 1699061776, + "vector_store_id": "vs_abc123" + } + ], + "first_id": "file-abc123", + "last_id": "file-abc456", + "has_more": false + } + + /batches: + post: + summary: Creates and executes a batch from an uploaded file of requests + operationId: createBatch + tags: + - Batch + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - input_file_id + - endpoint + - completion_window + properties: + input_file_id: + type: string + description: | + The ID of an uploaded file that contains requests for the new batch. + + See [upload file](/docs/api-reference/files/create) for how to upload a file. + + Your input file must be formatted as a [JSONL file](/docs/api-reference/batch/request-input), and must be uploaded with the purpose `batch`. The file can contain up to 50,000 requests, and can be up to 100 MB in size. + endpoint: + type: string + enum: + [ + "/v1/chat/completions", + "/v1/embeddings", + "/v1/completions", + ] + description: The endpoint to be used for all requests in the batch. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. Note that `/v1/embeddings` batches are also restricted to a maximum of 50,000 embedding inputs across all requests in the batch. + completion_window: + type: string + enum: ["24h"] + description: The time frame within which the batch should be processed. Currently only `24h` is supported. + metadata: + type: object + additionalProperties: + type: string + description: Optional custom metadata for the batch. + nullable: true + responses: + "200": + description: Batch created successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Batch" + x-oaiMeta: + name: Create batch + group: batch + returns: The created [Batch](/docs/api-reference/batch/object) object. + examples: + request: + curl: | + curl https://api.openai.com/v1/batches \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "input_file_id": "file-abc123", + "endpoint": "/v1/chat/completions", + "completion_window": "24h" + }' + python: | + from openai import OpenAI + client = OpenAI() + + client.batches.create( + input_file_id="file-abc123", + endpoint="/v1/chat/completions", + completion_window="24h" + ) + node: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const batch = await openai.batches.create({ + input_file_id: "file-abc123", + endpoint: "/v1/chat/completions", + completion_window: "24h" + }); + + console.log(batch); + } + + main(); + response: | + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "validating", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": null, + "expires_at": null, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 0, + "completed": 0, + "failed": 0 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + } + } + get: + operationId: listBatches + tags: + - Batch + summary: List your organization's batches. + parameters: + - in: query + name: after + required: false + schema: + type: string + description: *pagination_after_param_description + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + responses: + "200": + description: Batch listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ListBatchesResponse" + x-oaiMeta: + name: List batch + group: batch + returns: A list of paginated [Batch](/docs/api-reference/batch/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/batches?limit=2 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" + python: | + from openai import OpenAI + client = OpenAI() + + client.batches.list() + node: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const list = await openai.batches.list(); + + for await (const batch of list) { + console.log(batch); + } + } + + main(); + response: | + { + "object": "list", + "data": [ + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly job", + } + }, + { ... }, + ], + "first_id": "batch_abc123", + "last_id": "batch_abc456", + "has_more": true + } + + /batches/{batch_id}: + get: + operationId: retrieveBatch + tags: + - Batch + summary: Retrieves a batch. + parameters: + - in: path + name: batch_id + required: true + schema: + type: string + description: The ID of the batch to retrieve. + responses: + "200": + description: Batch retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Batch" + x-oaiMeta: + name: Retrieve batch + group: batch + returns: The [Batch](/docs/api-reference/batch/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/batches/batch_abc123 \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + python: | + from openai import OpenAI + client = OpenAI() + + client.batches.retrieve("batch_abc123") + node: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const batch = await openai.batches.retrieve("batch_abc123"); + + console.log(batch); + } + + main(); + response: &batch_object | + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "completed", + "output_file_id": "file-cvaTdG", + "error_file_id": "file-HOWS94", + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": 1711493133, + "completed_at": 1711493163, + "failed_at": null, + "expired_at": null, + "cancelling_at": null, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 95, + "failed": 5 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + } + } + + /batches/{batch_id}/cancel: + post: + operationId: cancelBatch + tags: + - Batch + summary: Cancels an in-progress batch. The batch will be in status `cancelling` for up to 10 minutes, before changing to `cancelled`, where it will have partial results (if any) available in the output file. + parameters: + - in: path + name: batch_id + required: true + schema: + type: string + description: The ID of the batch to cancel. + responses: + "200": + description: Batch is cancelling. Returns the cancelling batch's details. + content: + application/json: + schema: + $ref: "#/components/schemas/Batch" + x-oaiMeta: + name: Cancel batch + group: batch + returns: The [Batch](/docs/api-reference/batch/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/batches/batch_abc123/cancel \ + -H "Authorization: Bearer $OPENAI_API_KEY" \ + -H "Content-Type: application/json" \ + -X POST + python: | + from openai import OpenAI + client = OpenAI() + + client.batches.cancel("batch_abc123") + node: | + import OpenAI from "openai"; + + const openai = new OpenAI(); + + async function main() { + const batch = await openai.batches.cancel("batch_abc123"); + + console.log(batch); + } + + main(); + response: | + { + "id": "batch_abc123", + "object": "batch", + "endpoint": "/v1/chat/completions", + "errors": null, + "input_file_id": "file-abc123", + "completion_window": "24h", + "status": "cancelling", + "output_file_id": null, + "error_file_id": null, + "created_at": 1711471533, + "in_progress_at": 1711471538, + "expires_at": 1711557933, + "finalizing_at": null, + "completed_at": null, + "failed_at": null, + "expired_at": null, + "cancelling_at": 1711475133, + "cancelled_at": null, + "request_counts": { + "total": 100, + "completed": 23, + "failed": 1 + }, + "metadata": { + "customer_id": "user_123456789", + "batch_description": "Nightly eval job", + } + } + + # Organization + # Audit Logs List + /organization/audit_logs: + get: + summary: List user actions and configuration changes within this organization. + operationId: list-audit-logs + tags: + - Audit Logs + parameters: + - name: effective_at + in: query + description: Return only events whose `effective_at` (Unix seconds) is in this range. + required: false + schema: + type: object + properties: + gt: + type: integer + description: Return only events whose `effective_at` (Unix seconds) is greater than this value. + gte: + type: integer + description: Return only events whose `effective_at` (Unix seconds) is greater than or equal to this value. + lt: + type: integer + description: Return only events whose `effective_at` (Unix seconds) is less than this value. + lte: + type: integer + description: Return only events whose `effective_at` (Unix seconds) is less than or equal to this value. + - name: project_ids[] + in: query + description: Return only events for these projects. + required: false + schema: + type: array + items: + type: string + - name: event_types[] + in: query + description: Return only events with a `type` in one of these values. For example, `project.created`. For all options, see the documentation for the [audit log object](/docs/api-reference/audit-logs/object). + required: false + schema: + type: array + items: + $ref: "#/components/schemas/AuditLogEventType" + - name: actor_ids[] + in: query + description: Return only events performed by these actors. Can be a user ID, a service account ID, or an api key tracking ID. + required: false + schema: + type: array + items: + type: string + - name: actor_emails[] + in: query + description: Return only events performed by users with these emails. + required: false + schema: + type: array + items: + type: string + - name: resource_ids[] + in: query + description: Return only events performed on these targets. For example, a project ID updated. + required: false + schema: + type: array + items: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + schema: + type: string + - name: before + in: query + description: *pagination_before_param_description + schema: + type: string + responses: + "200": + description: Audit logs listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ListAuditLogsResponse" + x-oaiMeta: + name: List audit logs + group: audit-logs + returns: A list of paginated [Audit Log](/docs/api-reference/audit-logs/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/audit_logs \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + response: | + { + "object": "list", + "data": [ + { + "id": "audit_log-xxx_yyyymmdd", + "type": "project.archived", + "effective_at": 1722461446, + "actor": { + "type": "api_key", + "api_key": { + "type": "user", + "user": { + "id": "user-xxx", + "email": "user@example.com" + } + } + }, + "project.archived": { + "id": "proj_abc" + }, + }, + { + "id": "audit_log-yyy__20240101", + "type": "api_key.updated", + "effective_at": 1720804190, + "actor": { + "type": "session", + "session": { + "user": { + "id": "user-xxx", + "email": "user@example.com" + }, + "ip_address": "127.0.0.1", + "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + }, + "api_key.updated": { + "id": "key_xxxx", + "data": { + "scopes": ["resource_2.operation_2"] + } + }, + } + ], + "first_id": "audit_log-xxx__20240101", + "last_id": "audit_log_yyy__20240101", + "has_more": true + } + /organization/invites: + get: + summary: Returns a list of invites in the organization. + operationId: list-invites + tags: + - Invites + parameters: + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Invites listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/InviteListResponse" + x-oaiMeta: + name: List invites + group: administration + returns: A list of [Invite](/docs/api-reference/invite/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/invites?after=invite-abc&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.invite", + "id": "invite-abc", + "email": "user@example.com", + "role": "owner", + "status": "accepted", + "invited_at": 1711471533, + "expires_at": 1711471533, + "accepted_at": 1711471533 + } + ], + "first_id": "invite-abc", + "last_id": "invite-abc", + "has_more": false + } + + post: + summary: Create an invite for a user to the organization. The invite must be accepted by the user before they have access to the organization. + operationId: inviteUser + tags: + - Invites + requestBody: + description: The invite request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/InviteRequest" + responses: + "200": + description: User invited successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Invite" + x-oaiMeta: + name: Create invite + group: administration + returns: The created [Invite](/docs/api-reference/invite/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/invites \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "email": "user@example.com", + "role": "owner" + }' + response: + content: | + { + "object": "organization.invite", + "id": "invite-abc", + "email": "user@example.com", + "role": "owner", + "invited_at": 1711471533, + "expires_at": 1711471533, + "accepted_at": null + } + + /organization/invites/{invite_id}: + get: + summary: Retrieves an invite. + operationId: retrieve-invite + tags: + - Invites + parameters: + - in: path + name: invite_id + required: true + schema: + type: string + description: The ID of the invite to retrieve. + responses: + "200": + description: Invite retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Invite" + x-oaiMeta: + name: Retrieve invite + group: administration + returns: The [Invite](/docs/api-reference/invite/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/invites/invite-abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.invite", + "id": "invite-abc", + "email": "user@example.com", + "role": "owner", + "status": "accepted", + "invited_at": 1711471533, + "expires_at": 1711471533, + "accepted_at": 1711471533 + } + delete: + summary: Delete an invite. If the invite has already been accepted, it cannot be deleted. + operationId: delete-invite + tags: + - Invites + parameters: + - in: path + name: invite_id + required: true + schema: + type: string + description: The ID of the invite to delete. + responses: + "200": + description: Invite deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/InviteDeleteResponse" + x-oaiMeta: + name: Delete invite + group: administration + returns: Confirmation that the invite has been deleted + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/invites/invite-abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.invite.deleted", + "id": "invite-abc", + "deleted": true + } + + /organization/users: + get: + summary: Lists all of the users in the organization. + operationId: list-users + tags: + - Users + parameters: + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Users listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/UserListResponse" + x-oaiMeta: + name: List users + group: administration + returns: A list of [User](/docs/api-reference/users/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/users?after=user_abc&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + ], + "first_id": "user-abc", + "last_id": "user-xyz", + "has_more": false + } + + /organization/users/{user_id}: + get: + summary: Retrieves a user by their identifier. + operationId: retrieve-user + tags: + - Users + parameters: + - name: user_id + in: path + description: The ID of the user. + required: true + schema: + type: string + responses: + "200": + description: User retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/User" + x-oaiMeta: + name: Retrieve user + group: administration + returns: The [User](/docs/api-reference/users/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + post: + summary: Modifies a user's role in the organization. + operationId: modify-user + tags: + - Users + requestBody: + description: The new user role to modify. This must be one of `owner` or `member`. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/UserRoleUpdateRequest" + responses: + "200": + description: User role updated successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/User" + x-oaiMeta: + name: Modify user + group: administration + returns: The updated [User](/docs/api-reference/users/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "role": "owner" + }' + response: + content: | + { + "object": "organization.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + delete: + summary: Deletes a user from the organization. + operationId: delete-user + tags: + - Users + parameters: + - name: user_id + in: path + description: The ID of the user. + required: true + schema: + type: string + responses: + "200": + description: User deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/UserDeleteResponse" + x-oaiMeta: + name: Delete user + group: administration + returns: Confirmation of the deleted user + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.user.deleted", + "id": "user_abc", + "deleted": true + } + /organization/projects: + get: + summary: Returns a list of projects. + operationId: list-projects + tags: + - Projects + parameters: + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + - name: include_archived + in: query + schema: + type: boolean + default: false + description: If `true` returns all projects including those that have been `archived`. Archived projects are not included by default. + responses: + "200": + description: Projects listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectListResponse" + x-oaiMeta: + name: List projects + group: administration + returns: A list of [Project](/docs/api-reference/projects/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects?after=proj_abc&limit=20&include_archived=false \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project example", + "created_at": 1711471533, + "archived_at": null, + "status": "active" + } + ], + "first_id": "proj-abc", + "last_id": "proj-xyz", + "has_more": false + } + + post: + summary: Create a new project in the organization. Projects can be created and archived, but cannot be deleted. + operationId: create-project + tags: + - Projects + requestBody: + description: The project create request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectCreateRequest" + responses: + "200": + description: Project created successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Project" + x-oaiMeta: + name: Create project + group: administration + returns: The created [Project](/docs/api-reference/projects/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Project ABC" + }' + response: + content: | + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project ABC", + "created_at": 1711471533, + "archived_at": null, + "status": "active" + } + + /organization/projects/{project_id}: + get: + summary: Retrieves a project. + operationId: retrieve-project + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + responses: + "200": + description: Project retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Project" + x-oaiMeta: + name: Retrieve project + group: administration + description: Retrieve a project. + returns: The [Project](/docs/api-reference/projects/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project example", + "created_at": 1711471533, + "archived_at": null, + "status": "active" + } + + post: + summary: Modifies a project in the organization. + operationId: modify-project + tags: + - Projects + requestBody: + description: The project update request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUpdateRequest" + responses: + "200": + description: Project updated successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Project" + "400": + description: Error response when updating the default project. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Modify project + group: administration + returns: The updated [Project](/docs/api-reference/projects/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Project DEF" + }' + + /organization/projects/{project_id}/archive: + post: + summary: Archives a project in the organization. Archived projects cannot be used or updated. + operationId: archive-project + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + responses: + "200": + description: Project archived successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/Project" + x-oaiMeta: + name: Archive project + group: administration + returns: The archived [Project](/docs/api-reference/projects/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc/archive \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project DEF", + "created_at": 1711471533, + "archived_at": 1711471533, + "status": "archived" + } + + /organization/projects/{project_id}/users: + get: + summary: Returns a list of users in the project. + operationId: list-project-users + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Project users listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUserListResponse" + "400": + description: Error response when project is archived. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: List project users + group: administration + returns: A list of [ProjectUser](/docs/api-reference/project-users/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/users?after=user_abc&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + ], + "first_id": "user-abc", + "last_id": "user-xyz", + "has_more": false + } + error_response: + content: | + { + "code": 400, + "message": "Project {name} is archived" + } + + post: + summary: Adds a user to the project. Users must already be members of the organization to be added to a project. + operationId: create-project-user + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + tags: + - Projects + requestBody: + description: The project user create request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUserCreateRequest" + responses: + "200": + description: User added to project successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUser" + "400": + description: Error response for various conditions. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Create project user + group: administration + returns: The created [ProjectUser](/docs/api-reference/project-users/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc/users \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "user_id": "user_abc", + "role": "member" + }' + response: + content: | + { + "object": "organization.project.user", + "id": "user_abc", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + error_response: + content: | + { + "code": 400, + "message": "Project {name} is archived" + } + + /organization/projects/{project_id}/users/{user_id}: + get: + summary: Retrieves a user in the project. + operationId: retrieve-project-user + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: user_id + in: path + description: The ID of the user. + required: true + schema: + type: string + responses: + "200": + description: Project user retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUser" + x-oaiMeta: + name: Retrieve project user + group: administration + returns: The [ProjectUser](/docs/api-reference/project-users/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + post: + summary: Modifies a user's role in the project. + operationId: modify-project-user + tags: + - Projects + requestBody: + description: The project user update request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUserUpdateRequest" + responses: + "200": + description: Project user's role updated successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUser" + "400": + description: Error response for various conditions. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Modify project user + group: administration + returns: The updated [ProjectUser](/docs/api-reference/project-users/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "role": "owner" + }' + response: + content: | + { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + delete: + summary: Deletes a user from the project. + operationId: delete-project-user + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: user_id + in: path + description: The ID of the user. + required: true + schema: + type: string + responses: + "200": + description: Project user deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectUserDeleteResponse" + "400": + description: Error response for various conditions. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Delete project user + group: administration + returns: Confirmation that project has been deleted or an error in case of an archived project, which has no users + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/projects/proj_abc/users/user_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.user.deleted", + "id": "user_abc", + "deleted": true + } + + /organization/projects/{project_id}/service_accounts: + get: + summary: Returns a list of service accounts in the project. + operationId: list-project-service-accounts + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Project service accounts listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccountListResponse" + "400": + description: Error response when project is archived. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: List project service accounts + group: administration + returns: A list of [ProjectServiceAccount](/docs/api-reference/project-service-accounts/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/service_accounts?after=custom_id&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.project.service_account", + "id": "svc_acct_abc", + "name": "Service Account", + "role": "owner", + "created_at": 1711471533 + } + ], + "first_id": "svc_acct_abc", + "last_id": "svc_acct_xyz", + "has_more": false + } + + post: + summary: Creates a new service account in the project. This also returns an unredacted API key for the service account. + operationId: create-project-service-account + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + requestBody: + description: The project service account create request payload. + required: true + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccountCreateRequest" + responses: + "200": + description: Project service account created successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccountCreateResponse" + "400": + description: Error response when project is archived. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Create project service account + group: administration + returns: The created [ProjectServiceAccount](/docs/api-reference/project-service-accounts/object) object. + examples: + request: + curl: | + curl -X POST https://api.openai.com/v1/organization/projects/proj_abc/service_accounts \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" \ + -d '{ + "name": "Production App" + }' + response: + content: | + { + "object": "organization.project.service_account", + "id": "svc_acct_abc", + "name": "Production App", + "role": "member", + "created_at": 1711471533, + "api_key": { + "object": "organization.project.service_account.api_key", + "value": "sk-abcdefghijklmnop123", + "name": "Secret Key", + "created_at": 1711471533, + "id": "key_abc" + } + } + + /organization/projects/{project_id}/service_accounts/{service_account_id}: + get: + summary: Retrieves a service account in the project. + operationId: retrieve-project-service-account + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: service_account_id + in: path + description: The ID of the service account. + required: true + schema: + type: string + responses: + "200": + description: Project service account retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccount" + x-oaiMeta: + name: Retrieve project service account + group: administration + returns: The [ProjectServiceAccount](/docs/api-reference/project-service-accounts/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/service_accounts/svc_acct_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.service_account", + "id": "svc_acct_abc", + "name": "Service Account", + "role": "owner", + "created_at": 1711471533 + } + + delete: + summary: Deletes a service account from the project. + operationId: delete-project-service-account + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: service_account_id + in: path + description: The ID of the service account. + required: true + schema: + type: string + responses: + "200": + description: Project service account deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectServiceAccountDeleteResponse" + x-oaiMeta: + name: Delete project service account + group: administration + returns: Confirmation of service account being deleted, or an error in case of an archived project, which has no service accounts + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/projects/proj_abc/service_accounts/svc_acct_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.service_account.deleted", + "id": "svc_acct_abc", + "deleted": true + } + + /organization/projects/{project_id}/api_keys: + get: + summary: Returns a list of API keys in the project. + operationId: list-project-api-keys + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: limit + in: query + description: *pagination_limit_param_description + required: false + schema: + type: integer + default: 20 + - name: after + in: query + description: *pagination_after_param_description + required: false + schema: + type: string + responses: + "200": + description: Project API keys listed successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectApiKeyListResponse" + + x-oaiMeta: + name: List project API keys + group: administration + returns: A list of [ProjectApiKey](/docs/api-reference/project-api-keys/object) objects. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/api_keys?after=key_abc&limit=20 \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "list", + "data": [ + { + "object": "organization.project.api_key", + "redacted_value": "sk-abc...def", + "name": "My API Key", + "created_at": 1711471533, + "id": "key_abc", + "owner": { + "type": "user", + "user": { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + } + } + ], + "first_id": "key_abc", + "last_id": "key_xyz", + "has_more": false + } + error_response: + content: | + { + "code": 400, + "message": "Project {name} is archived" + } + + /organization/projects/{project_id}/api_keys/{key_id}: + get: + summary: Retrieves an API key in the project. + operationId: retrieve-project-api-key + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: key_id + in: path + description: The ID of the API key. + required: true + schema: + type: string + responses: + "200": + description: Project API key retrieved successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectApiKey" + x-oaiMeta: + name: Retrieve project API key + group: administration + returns: The [ProjectApiKey](/docs/api-reference/project-api-keys/object) object matching the specified ID. + examples: + request: + curl: | + curl https://api.openai.com/v1/organization/projects/proj_abc/api_keys/key_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.api_key", + "redacted_value": "sk-abc...def", + "name": "My API Key", + "created_at": 1711471533, + "id": "key_abc", + "owner": { + "type": "user", + "user": { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + } + } + + delete: + summary: Deletes an API key from the project. + operationId: delete-project-api-key + tags: + - Projects + parameters: + - name: project_id + in: path + description: The ID of the project. + required: true + schema: + type: string + - name: key_id + in: path + description: The ID of the API key. + required: true + schema: + type: string + responses: + "200": + description: Project API key deleted successfully. + content: + application/json: + schema: + $ref: "#/components/schemas/ProjectApiKeyDeleteResponse" + "400": + description: Error response for various conditions. + content: + application/json: + schema: + $ref: "#/components/schemas/ErrorResponse" + x-oaiMeta: + name: Delete project API key + group: administration + returns: Confirmation of the key's deletion or an error if the key belonged to a service account + examples: + request: + curl: | + curl -X DELETE https://api.openai.com/v1/organization/projects/proj_abc/api_keys/key_abc \ + -H "Authorization: Bearer $OPENAI_ADMIN_KEY" \ + -H "Content-Type: application/json" + response: + content: | + { + "object": "organization.project.api_key.deleted", + "id": "key_abc", + "deleted": true + } + error_response: + content: | + { + "code": 400, + "message": "API keys cannot be deleted for service accounts, please delete the service account" + } + +components: + securitySchemes: + ApiKeyAuth: + type: http + scheme: "bearer" + + schemas: + Error: + type: object + properties: + code: + type: string + nullable: true + message: + type: string + nullable: false + param: + type: string + nullable: true + type: + type: string + nullable: false + required: + - type + - message + - param + - code + ErrorResponse: + type: object + properties: + error: + $ref: "#/components/schemas/Error" + required: + - error + + ListModelsResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/Model" + required: + - object + - data + DeleteModelResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + required: + - id + - object + - deleted + + CreateCompletionRequest: + type: object + properties: + model: + description: &model_description | + ID of the model to use. You can use the [List models](/docs/api-reference/models/list) API to see all of your available models, or see our [Model overview](/docs/models/overview) for descriptions of them. + anyOf: + - type: string + - type: string + enum: ["gpt-3.5-turbo-instruct", "davinci-002", "babbage-002"] + x-oaiTypeLabel: string + prompt: + description: &completions_prompt_description | + The prompt(s) to generate completions for, encoded as a string, array of strings, array of tokens, or array of token arrays. + + Note that <|endoftext|> is the document separator that the model sees during training, so if a prompt is not specified the model will generate as if from the beginning of a new document. + default: "<|endoftext|>" + nullable: true + oneOf: + - type: string + default: "" + example: "This is a test." + - type: array + items: + type: string + default: "" + example: "This is a test." + - type: array + minItems: 1 + items: + type: integer + example: "[1212, 318, 257, 1332, 13]" + - type: array + minItems: 1 + items: + type: array + minItems: 1 + items: + type: integer + example: "[[1212, 318, 257, 1332, 13]]" + best_of: + type: integer + default: 1 + minimum: 0 + maximum: 20 + nullable: true + description: &completions_best_of_description | + Generates `best_of` completions server-side and returns the "best" (the one with the highest log probability per token). Results cannot be streamed. + + When used with `n`, `best_of` controls the number of candidate completions and `n` specifies how many to return – `best_of` must be greater than `n`. + + **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. + echo: + type: boolean + default: false + nullable: true + description: &completions_echo_description > + Echo back the prompt in addition to the completion + frequency_penalty: + type: number + default: 0 + minimum: -2 + maximum: 2 + nullable: true + description: &completions_frequency_penalty_description | + Number between -2.0 and 2.0. Positive values penalize new tokens based on their existing frequency in the text so far, decreasing the model's likelihood to repeat the same line verbatim. + + [See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details) + logit_bias: &completions_logit_bias + type: object + x-oaiTypeLabel: map + default: null + nullable: true + additionalProperties: + type: integer + description: &completions_logit_bias_description | + Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the GPT tokenizer) to an associated bias value from -100 to 100. You can use this [tokenizer tool](/tokenizer?view=bpe) to convert text to token IDs. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + + As an example, you can pass `{"50256": -100}` to prevent the <|endoftext|> token from being generated. + logprobs: &completions_logprobs_configuration + type: integer + minimum: 0 + maximum: 5 + default: null + nullable: true + description: &completions_logprobs_description | + Include the log probabilities on the `logprobs` most likely output tokens, as well the chosen tokens. For example, if `logprobs` is 5, the API will return a list of the 5 most likely tokens. The API will always return the `logprob` of the sampled token, so there may be up to `logprobs+1` elements in the response. + + The maximum value for `logprobs` is 5. + max_tokens: + type: integer + minimum: 0 + default: 16 + example: 16 + nullable: true + description: &completions_max_tokens_description | + The maximum number of [tokens](/tokenizer) that can be generated in the completion. + + The token count of your prompt plus `max_tokens` cannot exceed the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + n: + type: integer + minimum: 1 + maximum: 128 + default: 1 + example: 1 + nullable: true + description: &completions_completions_description | + How many completions to generate for each prompt. + + **Note:** Because this parameter generates many completions, it can quickly consume your token quota. Use carefully and ensure that you have reasonable settings for `max_tokens` and `stop`. + presence_penalty: + type: number + default: 0 + minimum: -2 + maximum: 2 + nullable: true + description: &completions_presence_penalty_description | + Number between -2.0 and 2.0. Positive values penalize new tokens based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics. + + [See more information about frequency and presence penalties.](/docs/guides/text-generation/parameter-details) + seed: &completions_seed_param + type: integer + minimum: -9223372036854775808 + maximum: 9223372036854775807 + nullable: true + description: | + If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result. + + Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend. + stop: + description: &completions_stop_description > + Up to 4 sequences where the API will stop generating further tokens. The returned text will not contain the stop sequence. + default: null + nullable: true + oneOf: + - type: string + default: <|endoftext|> + example: "\n" + nullable: true + - type: array + minItems: 1 + maxItems: 4 + items: + type: string + example: '["\n"]' + stream: + description: > + Whether to stream back partial progress. If set, tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + type: boolean + nullable: true + default: false + stream_options: + $ref: "#/components/schemas/ChatCompletionStreamOptions" + suffix: + description: | + The suffix that comes after a completion of inserted text. + + This parameter is only supported for `gpt-3.5-turbo-instruct`. + default: null + nullable: true + type: string + example: "test." + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: &completions_temperature_description | + What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + + We generally recommend altering this or `top_p` but not both. + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: &completions_top_p_description | + An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + user: &end_user_param_configuration + type: string + example: user-1234 + description: | + A unique identifier representing your end-user, which can help OpenAI to monitor and detect abuse. [Learn more](/docs/guides/safety-best-practices/end-user-ids). + required: + - model + - prompt + + CreateCompletionResponse: + type: object + description: | + Represents a completion response from the API. Note: both the streamed and non-streamed response objects share the same shape (unlike the chat endpoint). + properties: + id: + type: string + description: A unique identifier for the completion. + choices: + type: array + description: The list of completion choices the model generated for the input prompt. + items: + type: object + required: + - finish_reason + - index + - logprobs + - text + properties: + finish_reason: + type: string + description: &completion_finish_reason_description | + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, + `length` if the maximum number of tokens specified in the request was reached, + or `content_filter` if content was omitted due to a flag from our content filters. + enum: ["stop", "length", "content_filter"] + nullable: true + index: + type: integer + logprobs: + type: object + nullable: true + properties: + text_offset: + type: array + items: + type: integer + token_logprobs: + type: array + items: + type: number + tokens: + type: array + items: + type: string + top_logprobs: + type: array + items: + type: object + additionalProperties: + type: number + text: + type: string + created: + type: integer + description: The Unix timestamp (in seconds) of when the completion was created. + model: + type: string + description: The model used for completion. + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always "text_completion" + enum: [text_completion] + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - id + - object + - created + - model + - choices + x-oaiMeta: + name: The completion object + legacy: true + example: | + { + "id": "cmpl-uqkvlQyYK7bGYrRHQ0eXlWi7", + "object": "text_completion", + "created": 1589478378, + "model": "gpt-4-turbo", + "choices": [ + { + "text": "\n\nThis is indeed a test", + "index": 0, + "logprobs": null, + "finish_reason": "length" + } + ], + "usage": { + "prompt_tokens": 5, + "completion_tokens": 7, + "total_tokens": 12 + } + } + + ChatCompletionRequestMessageContentPartText: + type: object + title: Text content part + properties: + type: + type: string + enum: ["text"] + description: The type of the content part. + text: + type: string + description: The text content. + required: + - type + - text + + ChatCompletionRequestMessageContentPartImage: + type: object + title: Image content part + properties: + type: + type: string + enum: ["image_url"] + description: The type of the content part. + image_url: + type: object + properties: + url: + type: string + description: Either a URL of the image or the base64 encoded image data. + format: uri + detail: + type: string + description: Specifies the detail level of the image. Learn more in the [Vision guide](/docs/guides/vision/low-or-high-fidelity-image-understanding). + enum: ["auto", "low", "high"] + default: "auto" + required: + - url + required: + - type + - image_url + + ChatCompletionRequestMessageContentPartRefusal: + type: object + title: Refusal content part + properties: + type: + type: string + enum: ["refusal"] + description: The type of the content part. + refusal: + type: string + description: The refusal message generated by the model. + required: + - type + - refusal + + ChatCompletionRequestMessage: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestSystemMessage" + - $ref: "#/components/schemas/ChatCompletionRequestUserMessage" + - $ref: "#/components/schemas/ChatCompletionRequestAssistantMessage" + - $ref: "#/components/schemas/ChatCompletionRequestToolMessage" + - $ref: "#/components/schemas/ChatCompletionRequestFunctionMessage" + x-oaiExpandable: true + discriminator: + propertyName: role + mapping: + system: "#/components/schemas/ChatCompletionRequestSystemMessage" + user: "#/components/schemas/ChatCompletionRequestUserMessage" + assistant: "#/components/schemas/ChatCompletionRequestAssistantMessage" + tool: "#/components/schemas/ChatCompletionRequestToolMessage" + function: "#/components/schemas/ChatCompletionRequestFunctionMessage" + + ChatCompletionRequestSystemMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" + x-oaiExpandable: true + + ChatCompletionRequestUserMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartImage" + x-oaiExpandable: true + + ChatCompletionRequestAssistantMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartRefusal" + x-oaiExpandable: true + + ChatCompletionRequestToolMessageContentPart: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestMessageContentPartText" + x-oaiExpandable: true + + ChatCompletionRequestSystemMessage: + type: object + title: System message + properties: + content: + description: The contents of the system message. + oneOf: + - type: string + description: The contents of the system message. + title: Text content + - type: array + description: An array of content parts with a defined type. For system messages, only type `text` is supported. + title: Array of content parts + items: + $ref: "#/components/schemas/ChatCompletionRequestSystemMessageContentPart" + minItems: 1 + role: + type: string + enum: ["system"] + description: The role of the messages author, in this case `system`. + name: + type: string + description: An optional name for the participant. Provides the model information to differentiate between participants of the same role. + required: + - content + - role + + ChatCompletionRequestUserMessage: + type: object + title: User message + properties: + content: + description: | + The contents of the user message. + oneOf: + - type: string + description: The text contents of the message. + title: Text content + - type: array + description: An array of content parts with a defined type, each can be of type `text` or `image_url` when passing in images. You can pass multiple images by adding multiple `image_url` content parts. Image input is only supported when using the `gpt-4o` model. + title: Array of content parts + items: + $ref: "#/components/schemas/ChatCompletionRequestUserMessageContentPart" + minItems: 1 + x-oaiExpandable: true + role: + type: string + enum: ["user"] + description: The role of the messages author, in this case `user`. + name: + type: string + description: An optional name for the participant. Provides the model information to differentiate between participants of the same role. + required: + - content + - role + + ChatCompletionRequestAssistantMessage: + type: object + title: Assistant message + properties: + content: + nullable: true + oneOf: + - type: string + description: The contents of the assistant message. + title: Text content + - type: array + description: An array of content parts with a defined type. Can be one or more of type `text`, or exactly one of type `refusal`. + title: Array of content parts + items: + $ref: "#/components/schemas/ChatCompletionRequestAssistantMessageContentPart" + minItems: 1 + description: | + The contents of the assistant message. Required unless `tool_calls` or `function_call` is specified. + refusal: + nullable: true + type: string + description: The refusal message by the assistant. + role: + type: string + enum: ["assistant"] + description: The role of the messages author, in this case `assistant`. + name: + type: string + description: An optional name for the participant. Provides the model information to differentiate between participants of the same role. + tool_calls: + $ref: "#/components/schemas/ChatCompletionMessageToolCalls" + function_call: + type: object + deprecated: true + description: "Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + nullable: true + properties: + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + name: + type: string + description: The name of the function to call. + required: + - arguments + - name + required: + - role + + FineTuneChatCompletionRequestAssistantMessage: + allOf: + - type: object + title: Assistant message + deprecated: false + properties: + weight: + type: integer + enum: [0, 1] + description: "Controls whether the assistant message is trained against (0 or 1)" + - $ref: "#/components/schemas/ChatCompletionRequestAssistantMessage" + required: + - role + + ChatCompletionRequestToolMessage: + type: object + title: Tool message + properties: + role: + type: string + enum: ["tool"] + description: The role of the messages author, in this case `tool`. + content: + oneOf: + - type: string + description: The contents of the tool message. + title: Text content + - type: array + description: An array of content parts with a defined type. For tool messages, only type `text` is supported. + title: Array of content parts + items: + $ref: "#/components/schemas/ChatCompletionRequestToolMessageContentPart" + minItems: 1 + description: The contents of the tool message. + tool_call_id: + type: string + description: Tool call that this message is responding to. + required: + - role + - content + - tool_call_id + + ChatCompletionRequestFunctionMessage: + type: object + title: Function message + deprecated: true + properties: + role: + type: string + enum: ["function"] + description: The role of the messages author, in this case `function`. + content: + nullable: true + type: string + description: The contents of the function message. + name: + type: string + description: The name of the function to call. + required: + - role + - content + - name + + FunctionParameters: + type: object + description: "The parameters the functions accepts, described as a JSON Schema object. See the [guide](/docs/guides/function-calling) for examples, and the [JSON Schema reference](https://json-schema.org/understanding-json-schema/) for documentation about the format. \n\nOmitting `parameters` defines a function with an empty parameter list." + additionalProperties: true + + ChatCompletionFunctions: + type: object + deprecated: true + properties: + description: + type: string + description: A description of what the function does, used by the model to choose when and how to call the function. + name: + type: string + description: The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + parameters: + $ref: "#/components/schemas/FunctionParameters" + required: + - name + + ChatCompletionFunctionCallOption: + type: object + description: > + Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + properties: + name: + type: string + description: The name of the function to call. + required: + - name + + ChatCompletionTool: + type: object + properties: + type: + type: string + enum: ["function"] + description: The type of the tool. Currently, only `function` is supported. + function: + $ref: "#/components/schemas/FunctionObject" + required: + - type + - function + + FunctionObject: + type: object + properties: + description: + type: string + description: A description of what the function does, used by the model to choose when and how to call the function. + name: + type: string + description: The name of the function to be called. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + parameters: + $ref: "#/components/schemas/FunctionParameters" + strict: + type: boolean + nullable: true + # default: false (TODO: dmchoi) revert once vllm updates their spec + description: Whether to enable strict schema adherence when generating the function call. If set to true, the model will follow the exact schema defined in the `parameters` field. Only a subset of JSON Schema is supported when `strict` is `true`. Learn more about Structured Outputs in the [function calling guide](docs/guides/function-calling). + required: + - name + + ResponseFormatText: + type: object + properties: + type: + type: string + description: "The type of response format being defined: `text`" + enum: ["text"] + required: + - type + + ResponseFormatJsonObject: + type: object + properties: + type: + type: string + description: "The type of response format being defined: `json_object`" + enum: ["json_object"] + required: + - type + + ResponseFormatJsonSchemaSchema: + type: object + description: "The schema for the response format, described as a JSON Schema object." + additionalProperties: true + + ResponseFormatJsonSchema: + type: object + properties: + type: + type: string + description: "The type of response format being defined: `json_schema`" + enum: ["json_schema"] + json_schema: + type: object + properties: + description: + type: string + description: A description of what the response format is for, used by the model to determine how to respond in the format. + name: + type: string + description: The name of the response format. Must be a-z, A-Z, 0-9, or contain underscores and dashes, with a maximum length of 64. + schema: + $ref: "#/components/schemas/ResponseFormatJsonSchemaSchema" + strict: + type: boolean + nullable: true + default: false + description: Whether to enable strict schema adherence when generating the output. If set to true, the model will always follow the exact schema defined in the `schema` field. Only a subset of JSON Schema is supported when `strict` is `true`. To learn more, read the [Structured Outputs guide](/docs/guides/structured-outputs). + required: + - type + - name + required: + - type + - json_schema + + ChatCompletionToolChoiceOption: + description: | + Controls which (if any) tool is called by the model. + `none` means the model will not call any tool and instead generates a message. + `auto` means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools. + Specifying a particular tool via `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. + + `none` is the default when no tools are present. `auto` is the default if tools are present. + oneOf: + - type: string + description: > + `none` means the model will not call any tool and instead generates a message. + `auto` means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools. + enum: [none, auto, required] + - $ref: "#/components/schemas/ChatCompletionNamedToolChoice" + x-oaiExpandable: true + + ChatCompletionNamedToolChoice: + type: object + description: Specifies a tool the model should use. Use to force the model to call a specific function. + properties: + type: + type: string + enum: ["function"] + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + required: + - name + required: + - type + - function + + ParallelToolCalls: + description: Whether to enable [parallel function calling](/docs/guides/function-calling/parallel-function-calling) during tool use. + type: boolean + default: true + + ChatCompletionMessageToolCalls: + type: array + description: The tool calls generated by the model, such as function calls. + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCall" + + ChatCompletionMessageToolCall: + type: object + properties: + # TODO: index included when streaming + id: + type: string + description: The ID of the tool call. + type: + type: string + enum: ["function"] + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + description: The function that the model called. + properties: + name: + type: string + description: The name of the function to call. + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + required: + - name + - arguments + required: + - id + - type + - function + + ChatCompletionMessageToolCallChunk: + type: object + properties: + index: + type: integer + id: + type: string + description: The ID of the tool call. + type: + type: string + enum: ["function"] + description: The type of the tool. Currently, only `function` is supported. + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + required: + - index + + # Note, this isn't referenced anywhere, but is kept as a convenience to record all possible roles in one place. + ChatCompletionRole: + type: string + description: The role of the author of a message + enum: + - system + - user + - assistant + - tool + - function + + ChatCompletionStreamOptions: + description: | + Options for streaming response. Only set this when you set `stream: true`. + type: object + nullable: true + default: null + properties: + include_usage: + type: boolean + description: | + If set, an additional chunk will be streamed before the `data: [DONE]` message. The `usage` field on this chunk shows the token usage statistics for the entire request, and the `choices` field will always be an empty array. All other chunks will also include a `usage` field, but with a null value. + + ChatCompletionResponseMessage: + type: object + description: A chat completion message generated by the model. + properties: + content: + type: string + description: The contents of the message. + nullable: true + refusal: + type: string + description: The refusal message generated by the model. + nullable: true + tool_calls: + $ref: "#/components/schemas/ChatCompletionMessageToolCalls" + role: + type: string + enum: ["assistant"] + description: The role of the author of this message. + function_call: + type: object + deprecated: true + description: "Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + properties: + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + name: + type: string + description: The name of the function to call. + required: + - name + - arguments + required: + - role + - content + # - refusal + + ChatCompletionStreamResponseDelta: + type: object + description: A chat completion delta generated by streamed model responses. + properties: + content: + type: string + description: The contents of the chunk message. + nullable: true + function_call: + deprecated: true + type: object + description: "Deprecated and replaced by `tool_calls`. The name and arguments of a function that should be called, as generated by the model." + properties: + arguments: + type: string + description: The arguments to call the function with, as generated by the model in JSON format. Note that the model does not always generate valid JSON, and may hallucinate parameters not defined by your function schema. Validate the arguments in your code before calling your function. + name: + type: string + description: The name of the function to call. + tool_calls: + type: array + items: + $ref: "#/components/schemas/ChatCompletionMessageToolCallChunk" + role: + type: string + enum: ["system", "user", "assistant", "tool"] + description: The role of the author of this message. + refusal: + type: string + description: The refusal message generated by the model. + nullable: true + + CreateChatCompletionRequest: + type: object + properties: + messages: + description: A list of messages comprising the conversation so far. [Example Python code](https://cookbook.openai.com/examples/how_to_format_inputs_to_chatgpt_models). + type: array + minItems: 1 + items: + $ref: "#/components/schemas/ChatCompletionRequestMessage" + model: + description: ID of the model to use. See the [model endpoint compatibility](/docs/models/model-endpoint-compatibility) table for details on which models work with the Chat API. + example: "gpt-4o" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "chatgpt-4o-latest", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0301", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + frequency_penalty: + type: number + default: 0 + minimum: -2 + maximum: 2 + nullable: true + description: *completions_frequency_penalty_description + logit_bias: + type: object + x-oaiTypeLabel: map + default: null + nullable: true + additionalProperties: + type: integer + description: | + Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the tokenizer) to an associated bias value from -100 to 100. Mathematically, the bias is added to the logits generated by the model prior to sampling. The exact effect will vary per model, but values between -1 and 1 should decrease or increase likelihood of selection; values like -100 or 100 should result in a ban or exclusive selection of the relevant token. + logprobs: + description: Whether to return log probabilities of the output tokens or not. If true, returns the log probabilities of each output token returned in the `content` of `message`. + type: boolean + default: false + nullable: true + top_logprobs: + description: An integer between 0 and 20 specifying the number of most likely tokens to return at each token position, each with an associated log probability. `logprobs` must be set to `true` if this parameter is used. + type: integer + minimum: 0 + maximum: 20 + nullable: true + max_tokens: + description: | + The maximum number of [tokens](/tokenizer) that can be generated in the chat completion. + + The total length of input tokens and generated tokens is limited by the model's context length. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + type: integer + nullable: true + n: + type: integer + minimum: 1 + maximum: 128 + default: 1 + example: 1 + nullable: true + description: How many chat completion choices to generate for each input message. Note that you will be charged based on the number of generated tokens across all of the choices. Keep `n` as `1` to minimize costs. + presence_penalty: + type: number + default: 0 + minimum: -2 + maximum: 2 + nullable: true + description: *completions_presence_penalty_description + response_format: + description: | + An object specifying the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4o mini](/docs/models/gpt-4o-mini), [GPT-4 Turbo](/docs/models/gpt-4-and-gpt-4-turbo) and all GPT-3.5 Turbo models newer than `gpt-3.5-turbo-1106`. + + Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs). + + Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. + + **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. + oneOf: + - $ref: "#/components/schemas/ResponseFormatText" + - $ref: "#/components/schemas/ResponseFormatJsonObject" + - $ref: "#/components/schemas/ResponseFormatJsonSchema" + x-oaiExpandable: true + seed: + type: integer + minimum: -9223372036854775808 + maximum: 9223372036854775807 + nullable: true + description: | + This feature is in Beta. + If specified, our system will make a best effort to sample deterministically, such that repeated requests with the same `seed` and parameters should return the same result. + Determinism is not guaranteed, and you should refer to the `system_fingerprint` response parameter to monitor changes in the backend. + x-oaiMeta: + beta: true + service_tier: + description: | + Specifies the latency tier to use for processing the request. This parameter is relevant for customers subscribed to the scale tier service: + - If set to 'auto', the system will utilize scale tier credits until they are exhausted. + - If set to 'default', the request will be processed using the default service tier with a lower uptime SLA and no latency guarentee. + - When not set, the default behavior is 'auto'. + + When this parameter is set, the response body will include the `service_tier` utilized. + type: string + enum: ["auto", "default"] + nullable: true + default: null + stop: + description: | + Up to 4 sequences where the API will stop generating further tokens. + default: null + oneOf: + - type: string + nullable: true + - type: array + minItems: 1 + maxItems: 4 + items: + type: string + stream: + description: > + If set, partial message deltas will be sent, like in ChatGPT. Tokens will be sent as data-only [server-sent events](https://developer.mozilla.org/en-US/docs/Web/API/Server-sent_events/Using_server-sent_events#Event_stream_format) + as they become available, with the stream terminated by a `data: [DONE]` message. [Example Python code](https://cookbook.openai.com/examples/how_to_stream_completions). + type: boolean + nullable: true + default: false + stream_options: + $ref: "#/components/schemas/ChatCompletionStreamOptions" + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: *completions_temperature_description + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *completions_top_p_description + tools: + type: array + description: > + A list of tools the model may call. Currently, only functions are supported as a tool. + Use this to provide a list of functions the model may generate JSON inputs for. A max of 128 functions are supported. + items: + $ref: "#/components/schemas/ChatCompletionTool" + tool_choice: + $ref: "#/components/schemas/ChatCompletionToolChoiceOption" + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + user: *end_user_param_configuration + function_call: + deprecated: true + description: | + Deprecated in favor of `tool_choice`. + + Controls which (if any) function is called by the model. + `none` means the model will not call a function and instead generates a message. + `auto` means the model can pick between generating a message or calling a function. + Specifying a particular function via `{"name": "my_function"}` forces the model to call that function. + + `none` is the default when no functions are present. `auto` is the default if functions are present. + oneOf: + - type: string + description: > + `none` means the model will not call a function and instead generates a message. + `auto` means the model can pick between generating a message or calling a function. + enum: [none, auto] + - $ref: "#/components/schemas/ChatCompletionFunctionCallOption" + x-oaiExpandable: true + functions: + deprecated: true + description: | + Deprecated in favor of `tools`. + + A list of functions the model may generate JSON inputs for. + type: array + minItems: 1 + maxItems: 128 + items: + $ref: "#/components/schemas/ChatCompletionFunctions" + + required: + - model + - messages + + CreateChatCompletionResponse: + type: object + description: Represents a chat completion response returned by model, based on the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. + choices: + type: array + description: A list of chat completion choices. Can be more than one if `n` is greater than 1. + items: + type: object + required: + - finish_reason + - index + - message + # - logprobs + properties: + finish_reason: + type: string + description: &chat_completion_finish_reason_description | + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, + `length` if the maximum number of tokens specified in the request was reached, + `content_filter` if content was omitted due to a flag from our content filters, + `tool_calls` if the model called a tool, or `function_call` (deprecated) if the model called a function. + enum: + [ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + ] + index: + type: integer + description: The index of the choice in the list of choices. + message: + $ref: "#/components/schemas/ChatCompletionResponseMessage" + logprobs: &chat_completion_response_logprobs + description: Log probability information for the choice. + type: object + nullable: true + properties: + content: + description: A list of message content tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + nullable: true + refusal: + description: A list of message refusal tokens with log probability information. + type: array + items: + $ref: "#/components/schemas/ChatCompletionTokenLogprob" + nullable: true + required: + - content + # - refusal + + created: + type: integer + description: The Unix timestamp (in seconds) of when the chat completion was created. + model: + type: string + description: The model used for the chat completion. + service_tier: + description: The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. + type: string + enum: ["scale", "default"] + example: "scale" + nullable: true + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always `chat.completion`. + enum: [chat.completion] + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + x-oaiMeta: + name: The chat completion object + group: chat + example: *chat_completion_example + + CreateChatCompletionFunctionResponse: + type: object + description: Represents a chat completion response returned by model, based on the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. + choices: + type: array + description: A list of chat completion choices. Can be more than one if `n` is greater than 1. + items: + type: object + required: + - finish_reason + - index + - message + - logprobs + properties: + finish_reason: + type: string + description: + &chat_completion_function_finish_reason_description | + The reason the model stopped generating tokens. This will be `stop` if the model hit a natural stop point or a provided stop sequence, `length` if the maximum number of tokens specified in the request was reached, `content_filter` if content was omitted due to a flag from our content filters, or `function_call` if the model called a function. + enum: ["stop", "length", "function_call", "content_filter"] + index: + type: integer + description: The index of the choice in the list of choices. + message: + $ref: "#/components/schemas/ChatCompletionResponseMessage" + created: + type: integer + description: The Unix timestamp (in seconds) of when the chat completion was created. + model: + type: string + description: The model used for the chat completion. + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always `chat.completion`. + enum: [chat.completion] + usage: + $ref: "#/components/schemas/CompletionUsage" + required: + - choices + - created + - id + - model + - object + x-oaiMeta: + name: The chat completion object + group: chat + example: *chat_completion_function_example + + ChatCompletionTokenLogprob: + type: object + properties: + token: &chat_completion_response_logprobs_token + description: The token. + type: string + logprob: &chat_completion_response_logprobs_token_logprob + description: The log probability of this token, if it is within the top 20 most likely tokens. Otherwise, the value `-9999.0` is used to signify that the token is very unlikely. + type: number + bytes: &chat_completion_response_logprobs_bytes + description: A list of integers representing the UTF-8 bytes representation of the token. Useful in instances where characters are represented by multiple tokens and their byte representations must be combined to generate the correct text representation. Can be `null` if there is no bytes representation for the token. + type: array + items: + type: integer + nullable: true + top_logprobs: + description: List of the most likely tokens and their log probability, at this token position. In rare cases, there may be fewer than the number of requested `top_logprobs` returned. + type: array + items: + type: object + properties: + token: *chat_completion_response_logprobs_token + logprob: *chat_completion_response_logprobs_token_logprob + bytes: *chat_completion_response_logprobs_bytes + required: + - token + - logprob + - bytes + required: + - token + - logprob + - bytes + - top_logprobs + + ListPaginatedFineTuningJobsResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/FineTuningJob" + has_more: + type: boolean + object: + type: string + enum: [list] + required: + - object + - data + - has_more + + CreateChatCompletionStreamResponse: + type: object + description: Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + properties: + id: + type: string + description: A unique identifier for the chat completion. Each chunk has the same ID. + choices: + type: array + description: | + A list of chat completion choices. Can contain more than one elements if `n` is greater than 1. Can also be empty for the + last chunk if you set `stream_options: {"include_usage": true}`. + items: + type: object + required: + - delta + - finish_reason + - index + properties: + delta: + $ref: "#/components/schemas/ChatCompletionStreamResponseDelta" + logprobs: *chat_completion_response_logprobs + finish_reason: + type: string + description: *chat_completion_finish_reason_description + enum: + [ + "stop", + "length", + "tool_calls", + "content_filter", + "function_call", + ] + nullable: true + index: + type: integer + description: The index of the choice in the list of choices. + created: + type: integer + description: The Unix timestamp (in seconds) of when the chat completion was created. Each chunk has the same timestamp. + model: + type: string + description: The model to generate the completion. + service_tier: + description: The service tier used for processing the request. This field is only included if the `service_tier` parameter is specified in the request. + type: string + enum: ["scale", "default"] + example: "scale" + nullable: true + system_fingerprint: + type: string + description: | + This fingerprint represents the backend configuration that the model runs with. + Can be used in conjunction with the `seed` request parameter to understand when backend changes have been made that might impact determinism. + object: + type: string + description: The object type, which is always `chat.completion.chunk`. + enum: [chat.completion.chunk] + usage: + type: object + description: | + An optional field that will only be present when you set `stream_options: {"include_usage": true}` in your request. + When present, it contains a null value except for the last chunk which contains the token usage statistics for the entire request. + properties: + completion_tokens: + type: integer + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + description: Number of tokens in the prompt. + total_tokens: + type: integer + description: Total number of tokens used in the request (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + required: + - choices + - created + - id + - model + - object + x-oaiMeta: + name: The chat completion chunk object + group: chat + example: *chat_completion_chunk_example + + CreateChatCompletionImageResponse: + type: object + description: Represents a streamed chunk of a chat completion response returned by model, based on the provided input. + x-oaiMeta: + name: The chat completion chunk object + group: chat + example: *chat_completion_image_example + + CreateImageRequest: + type: object + properties: + prompt: + description: A text description of the desired image(s). The maximum length is 1000 characters for `dall-e-2` and 4000 characters for `dall-e-3`. + type: string + example: "A cute baby sea otter" + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2", "dall-e-3"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-3" + nullable: true + description: The model to use for image generation. + n: &images_n + type: integer + minimum: 1 + maximum: 10 + default: 1 + example: 1 + nullable: true + description: The number of images to generate. Must be between 1 and 10. For `dall-e-3`, only `n=1` is supported. + quality: + type: string + enum: ["standard", "hd"] + default: "standard" + example: "standard" + description: The quality of the image that will be generated. `hd` creates images with finer details and greater consistency across the image. This param is only supported for `dall-e-3`. + response_format: &images_response_format + type: string + enum: ["url", "b64_json"] + default: "url" + example: "url" + nullable: true + description: The format in which the generated images are returned. Must be one of `url` or `b64_json`. URLs are only valid for 60 minutes after the image has been generated. + size: &images_size + type: string + enum: ["256x256", "512x512", "1024x1024", "1792x1024", "1024x1792"] + default: "1024x1024" + example: "1024x1024" + nullable: true + description: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024` for `dall-e-2`. Must be one of `1024x1024`, `1792x1024`, or `1024x1792` for `dall-e-3` models. + style: + type: string + enum: ["vivid", "natural"] + default: "vivid" + example: "vivid" + nullable: true + description: The style of the generated images. Must be one of `vivid` or `natural`. Vivid causes the model to lean towards generating hyper-real and dramatic images. Natural causes the model to produce more natural, less hyper-real looking images. This param is only supported for `dall-e-3`. + user: *end_user_param_configuration + required: + - prompt + + ImagesResponse: + properties: + created: + type: integer + data: + type: array + items: + $ref: "#/components/schemas/Image" + required: + - created + - data + + Image: + type: object + description: Represents the url or the content of an image generated by the OpenAI API. + properties: + b64_json: + type: string + description: The base64-encoded JSON of the generated image, if `response_format` is `b64_json`. + url: + type: string + description: The URL of the generated image, if `response_format` is `url` (default). + revised_prompt: + type: string + description: The prompt that was used to generate the image, if there was any revision to the prompt. + x-oaiMeta: + name: The image object + example: | + { + "url": "...", + "revised_prompt": "..." + } + + CreateImageEditRequest: + type: object + properties: + image: + description: The image to edit. Must be a valid PNG file, less than 4MB, and square. If mask is not provided, image must have transparency, which will be used as the mask. + type: string + format: binary + prompt: + description: A text description of the desired image(s). The maximum length is 1000 characters. + type: string + example: "A cute baby sea otter wearing a beret" + mask: + description: An additional image whose fully transparent areas (e.g. where alpha is zero) indicate where `image` should be edited. Must be a valid PNG file, less than 4MB, and have the same dimensions as `image`. + type: string + format: binary + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-2" + nullable: true + description: The model to use for image generation. Only `dall-e-2` is supported at this time. + n: + type: integer + minimum: 1 + maximum: 10 + default: 1 + example: 1 + nullable: true + description: The number of images to generate. Must be between 1 and 10. + size: &dalle2_images_size + type: string + enum: ["256x256", "512x512", "1024x1024"] + default: "1024x1024" + example: "1024x1024" + nullable: true + description: The size of the generated images. Must be one of `256x256`, `512x512`, or `1024x1024`. + response_format: *images_response_format + user: *end_user_param_configuration + required: + - prompt + - image + + CreateImageVariationRequest: + type: object + properties: + image: + description: The image to use as the basis for the variation(s). Must be a valid PNG file, less than 4MB, and square. + type: string + format: binary + model: + anyOf: + - type: string + - type: string + enum: ["dall-e-2"] + x-oaiTypeLabel: string + default: "dall-e-2" + example: "dall-e-2" + nullable: true + description: The model to use for image generation. Only `dall-e-2` is supported at this time. + n: *images_n + response_format: *images_response_format + size: *dalle2_images_size + user: *end_user_param_configuration + required: + - image + + CreateModerationRequest: + type: object + properties: + input: + description: The input text to classify + oneOf: + - type: string + default: "" + example: "I want to kill them." + - type: array + items: + type: string + default: "" + example: "I want to kill them." + model: + description: | + Two content moderations models are available: `text-moderation-stable` and `text-moderation-latest`. + + The default is `text-moderation-latest` which will be automatically upgraded over time. This ensures you are always using our most accurate model. If you use `text-moderation-stable`, we will provide advanced notice before updating the model. Accuracy of `text-moderation-stable` may be slightly lower than for `text-moderation-latest`. + nullable: false + default: "text-moderation-latest" + example: "text-moderation-stable" + anyOf: + - type: string + - type: string + enum: ["text-moderation-latest", "text-moderation-stable"] + x-oaiTypeLabel: string + required: + - input + + CreateModerationResponse: + type: object + description: Represents if a given text input is potentially harmful. + properties: + id: + type: string + description: The unique identifier for the moderation request. + model: + type: string + description: The model used to generate the moderation results. + results: + type: array + description: A list of moderation objects. + items: + type: object + properties: + flagged: + type: boolean + description: Whether any of the below categories are flagged. + categories: + type: object + description: A list of the categories, and whether they are flagged or not. + properties: + hate: + type: boolean + description: Content that expresses, incites, or promotes hate based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. Hateful content aimed at non-protected groups (e.g., chess players) is harassment. + hate/threatening: + type: boolean + description: Hateful content that also includes violence or serious harm towards the targeted group based on race, gender, ethnicity, religion, nationality, sexual orientation, disability status, or caste. + harassment: + type: boolean + description: Content that expresses, incites, or promotes harassing language towards any target. + harassment/threatening: + type: boolean + description: Harassment content that also includes violence or serious harm towards any target. + self-harm: + type: boolean + description: Content that promotes, encourages, or depicts acts of self-harm, such as suicide, cutting, and eating disorders. + self-harm/intent: + type: boolean + description: Content where the speaker expresses that they are engaging or intend to engage in acts of self-harm, such as suicide, cutting, and eating disorders. + self-harm/instructions: + type: boolean + description: Content that encourages performing acts of self-harm, such as suicide, cutting, and eating disorders, or that gives instructions or advice on how to commit such acts. + sexual: + type: boolean + description: Content meant to arouse sexual excitement, such as the description of sexual activity, or that promotes sexual services (excluding sex education and wellness). + sexual/minors: + type: boolean + description: Sexual content that includes an individual who is under 18 years old. + violence: + type: boolean + description: Content that depicts death, violence, or physical injury. + violence/graphic: + type: boolean + description: Content that depicts death, violence, or physical injury in graphic detail. + required: + - hate + - hate/threatening + - harassment + - harassment/threatening + - self-harm + - self-harm/intent + - self-harm/instructions + - sexual + - sexual/minors + - violence + - violence/graphic + category_scores: + type: object + description: A list of the categories along with their scores as predicted by model. + properties: + hate: + type: number + description: The score for the category 'hate'. + hate/threatening: + type: number + description: The score for the category 'hate/threatening'. + harassment: + type: number + description: The score for the category 'harassment'. + harassment/threatening: + type: number + description: The score for the category 'harassment/threatening'. + self-harm: + type: number + description: The score for the category 'self-harm'. + self-harm/intent: + type: number + description: The score for the category 'self-harm/intent'. + self-harm/instructions: + type: number + description: The score for the category 'self-harm/instructions'. + sexual: + type: number + description: The score for the category 'sexual'. + sexual/minors: + type: number + description: The score for the category 'sexual/minors'. + violence: + type: number + description: The score for the category 'violence'. + violence/graphic: + type: number + description: The score for the category 'violence/graphic'. + required: + - hate + - hate/threatening + - harassment + - harassment/threatening + - self-harm + - self-harm/intent + - self-harm/instructions + - sexual + - sexual/minors + - violence + - violence/graphic + required: + - flagged + - categories + - category_scores + required: + - id + - model + - results + x-oaiMeta: + name: The moderation object + example: *moderation_example + + ListFilesResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/OpenAIFile" + object: + type: string + enum: [list] + required: + - object + - data + + CreateFileRequest: + type: object + additionalProperties: false + properties: + file: + description: | + The File object (not file name) to be uploaded. + type: string + format: binary + purpose: + description: | + The intended purpose of the uploaded file. + + Use "assistants" for [Assistants](/docs/api-reference/assistants) and [Message](/docs/api-reference/messages) files, "vision" for Assistants image file inputs, "batch" for [Batch API](/docs/guides/batch), and "fine-tune" for [Fine-tuning](/docs/api-reference/fine-tuning). + type: string + enum: ["assistants", "batch", "fine-tune", "vision"] + required: + - file + - purpose + + DeleteFileResponse: + type: object + properties: + id: + type: string + object: + type: string + enum: [file] + deleted: + type: boolean + required: + - id + - object + - deleted + + CreateUploadRequest: + type: object + additionalProperties: false + properties: + filename: + description: | + The name of the file to upload. + type: string + purpose: + description: | + The intended purpose of the uploaded file. + + See the [documentation on File purposes](/docs/api-reference/files/create#files-create-purpose). + type: string + enum: ["assistants", "batch", "fine-tune", "vision"] + bytes: + description: | + The number of bytes in the file you are uploading. + type: integer + mime_type: + description: | + The MIME type of the file. + + This must fall within the supported MIME types for your file purpose. See the supported MIME types for assistants and vision. + type: string + required: + - filename + - purpose + - bytes + - mime_type + + AddUploadPartRequest: + type: object + additionalProperties: false + properties: + data: + description: | + The chunk of bytes for this Part. + type: string + format: binary + required: + - data + + CompleteUploadRequest: + type: object + additionalProperties: false + properties: + part_ids: + type: array + description: | + The ordered list of Part IDs. + items: + type: string + md5: + description: | + The optional md5 checksum for the file contents to verify if the bytes uploaded matches what you expect. + type: string + required: + - part_ids + + CancelUploadRequest: + type: object + additionalProperties: false + + CreateFineTuningJobRequest: + type: object + properties: + model: + description: | + The name of the model to fine-tune. You can select one of the + [supported models](/docs/guides/fine-tuning/which-models-can-be-fine-tuned). + example: "gpt-4o-mini" + anyOf: + - type: string + - type: string + enum: + ["babbage-002", "davinci-002", "gpt-3.5-turbo", "gpt-4o-mini"] + x-oaiTypeLabel: string + training_file: + description: | + The ID of an uploaded file that contains training data. + + See [upload file](/docs/api-reference/files/create) for how to upload a file. + + Your dataset must be formatted as a JSONL file. Additionally, you must upload your file with the purpose `fine-tune`. + + The contents of the file should differ depending on if the model uses the [chat](/docs/api-reference/fine-tuning/chat-input) or [completions](/docs/api-reference/fine-tuning/completions-input) format. + + See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + type: string + example: "file-abc123" + hyperparameters: + type: object + description: The hyperparameters used for the fine-tuning job. + properties: + batch_size: + description: | + Number of examples in each batch. A larger batch size means that model parameters + are updated less frequently, but with lower variance. + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 256 + default: auto + learning_rate_multiplier: + description: | + Scaling factor for the learning rate. A smaller learning rate may be useful to avoid + overfitting. + oneOf: + - type: string + enum: [auto] + - type: number + minimum: 0 + exclusiveMinimum: true + default: auto + n_epochs: + description: | + The number of epochs to train the model for. An epoch refers to one full cycle + through the training dataset. + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 50 + default: auto + suffix: + description: | + A string of up to 18 characters that will be added to your fine-tuned model name. + + For example, a `suffix` of "custom-model-name" would produce a model name like `ft:gpt-4o-mini:openai:custom-model-name:7p4lURel`. + type: string + minLength: 1 + maxLength: 40 + default: null + nullable: true + validation_file: + description: | + The ID of an uploaded file that contains validation data. + + If you provide this file, the data is used to generate validation + metrics periodically during fine-tuning. These metrics can be viewed in + the fine-tuning results file. + The same data should not be present in both train and validation files. + + Your dataset must be formatted as a JSONL file. You must upload your file with the purpose `fine-tune`. + + See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + type: string + nullable: true + example: "file-abc123" + integrations: + type: array + description: A list of integrations to enable for your fine-tuning job. + nullable: true + items: + type: object + required: + - type + - wandb + properties: + type: + description: | + The type of integration to enable. Currently, only "wandb" (Weights and Biases) is supported. + oneOf: + - type: string + enum: [wandb] + wandb: + type: object + description: | + The settings for your integration with Weights and Biases. This payload specifies the project that + metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags + to your run, and set a default entity (team, username, etc) to be associated with your run. + required: + - project + properties: + project: + description: | + The name of the project that the new run will be created under. + type: string + example: "my-wandb-project" + name: + description: | + A display name to set for the run. If not set, we will use the Job ID as the name. + nullable: true + type: string + entity: + description: | + The entity to use for the run. This allows you to set the team or username of the WandB user that you would + like associated with the run. If not set, the default entity for the registered WandB API key is used. + nullable: true + type: string + tags: + description: | + A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some + default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + type: array + items: + type: string + example: "custom-tag" + + seed: + description: | + The seed controls the reproducibility of the job. Passing in the same seed and job parameters should produce the same results, but may differ in rare cases. + If a seed is not specified, one will be generated for you. + type: integer + nullable: true + minimum: 0 + maximum: 2147483647 + example: 42 + required: + - model + - training_file + + ListFineTuningJobEventsResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/FineTuningJobEvent" + object: + type: string + enum: [list] + required: + - object + - data + + ListFineTuningJobCheckpointsResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/FineTuningJobCheckpoint" + object: + type: string + enum: [list] + first_id: + type: string + nullable: true + last_id: + type: string + nullable: true + has_more: + type: boolean + required: + - object + - data + - has_more + + CreateEmbeddingRequest: + type: object + additionalProperties: false + properties: + input: + description: | + Input text to embed, encoded as a string or array of tokens. To embed multiple inputs in a single request, pass an array of strings or array of token arrays. The input must not exceed the max input tokens for the model (8192 tokens for `text-embedding-ada-002`), cannot be an empty string, and any array must be 2048 dimensions or less. [Example Python code](https://cookbook.openai.com/examples/how_to_count_tokens_with_tiktoken) for counting tokens. + example: "The quick brown fox jumped over the lazy dog" + oneOf: + - type: string + title: string + description: The string that will be turned into an embedding. + default: "" + example: "This is a test." + - type: array + title: array + description: The array of strings that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: string + default: "" + example: "['This is a test.']" + - type: array + title: array + description: The array of integers that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: integer + example: "[1212, 318, 257, 1332, 13]" + - type: array + title: array + description: The array of arrays containing integers that will be turned into an embedding. + minItems: 1 + maxItems: 2048 + items: + type: array + minItems: 1 + items: + type: integer + example: "[[1212, 318, 257, 1332, 13]]" + x-oaiExpandable: true + model: + description: *model_description + example: "text-embedding-3-small" + anyOf: + - type: string + - type: string + enum: + [ + "text-embedding-ada-002", + "text-embedding-3-small", + "text-embedding-3-large", + ] + x-oaiTypeLabel: string + encoding_format: + description: "The format to return the embeddings in. Can be either `float` or [`base64`](https://pypi.org/project/pybase64/)." + example: "float" + default: "float" + type: string + enum: ["float", "base64"] + dimensions: + description: | + The number of dimensions the resulting output embeddings should have. Only supported in `text-embedding-3` and later models. + type: integer + minimum: 1 + user: *end_user_param_configuration + required: + - model + - input + + CreateEmbeddingResponse: + type: object + properties: + data: + type: array + description: The list of embeddings generated by the model. + items: + $ref: "#/components/schemas/Embedding" + model: + type: string + description: The name of the model used to generate the embedding. + object: + type: string + description: The object type, which is always "list". + enum: [list] + usage: + type: object + description: The usage information for the request. + properties: + prompt_tokens: + type: integer + description: The number of tokens used by the prompt. + total_tokens: + type: integer + description: The total number of tokens used by the request. + required: + - prompt_tokens + - total_tokens + required: + - object + - model + - data + - usage + + CreateTranscriptionRequest: + type: object + additionalProperties: false + properties: + file: + description: | + The audio file object (not file name) to transcribe, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + type: string + x-oaiTypeLabel: file + format: binary + model: + description: | + ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. + example: whisper-1 + anyOf: + - type: string + - type: string + enum: ["whisper-1"] + x-oaiTypeLabel: string + language: + description: | + The language of the input audio. Supplying the input language in [ISO-639-1](https://en.wikipedia.org/wiki/List_of_ISO_639-1_codes) format will improve accuracy and latency. + type: string + prompt: + description: | + An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should match the audio language. + type: string + response_format: + description: | + The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`. + type: string + enum: + - json + - text + - srt + - verbose_json + - vtt + default: json + temperature: + description: | + The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + type: number + default: 0 + timestamp_granularities[]: + description: | + The timestamp granularities to populate for this transcription. `response_format` must be set `verbose_json` to use timestamp granularities. Either or both of these options are supported: `word`, or `segment`. Note: There is no additional latency for segment timestamps, but generating word timestamps incurs additional latency. + type: array + items: + type: string + enum: + - word + - segment + default: [segment] + required: + - file + - model + + # Note: This does not currently support the non-default response format types. + CreateTranscriptionResponseJson: + type: object + description: Represents a transcription response returned by model, based on the provided input. + properties: + text: + type: string + description: The transcribed text. + required: + - text + x-oaiMeta: + name: The transcription object (JSON) + group: audio + example: *basic_transcription_response_example + + TranscriptionSegment: + type: object + properties: + id: + type: integer + description: Unique identifier of the segment. + seek: + type: integer + description: Seek offset of the segment. + start: + type: number + format: float + description: Start time of the segment in seconds. + end: + type: number + format: float + description: End time of the segment in seconds. + text: + type: string + description: Text content of the segment. + tokens: + type: array + items: + type: integer + description: Array of token IDs for the text content. + temperature: + type: number + format: float + description: Temperature parameter used for generating the segment. + avg_logprob: + type: number + format: float + description: Average logprob of the segment. If the value is lower than -1, consider the logprobs failed. + compression_ratio: + type: number + format: float + description: Compression ratio of the segment. If the value is greater than 2.4, consider the compression failed. + no_speech_prob: + type: number + format: float + description: Probability of no speech in the segment. If the value is higher than 1.0 and the `avg_logprob` is below -1, consider this segment silent. + required: + - id + - seek + - start + - end + - text + - tokens + - temperature + - avg_logprob + - compression_ratio + - no_speech_prob + + TranscriptionWord: + type: object + properties: + word: + type: string + description: The text content of the word. + start: + type: number + format: float + description: Start time of the word in seconds. + end: + type: number + format: float + description: End time of the word in seconds. + required: [word, start, end] + + CreateTranscriptionResponseVerboseJson: + type: object + description: Represents a verbose json transcription response returned by model, based on the provided input. + properties: + language: + type: string + description: The language of the input audio. + duration: + type: string + description: The duration of the input audio. + text: + type: string + description: The transcribed text. + words: + type: array + description: Extracted words and their corresponding timestamps. + items: + $ref: "#/components/schemas/TranscriptionWord" + segments: + type: array + description: Segments of the transcribed text and their corresponding details. + items: + $ref: "#/components/schemas/TranscriptionSegment" + required: [language, duration, text] + x-oaiMeta: + name: The transcription object (Verbose JSON) + group: audio + example: *verbose_transcription_response_example + + CreateTranslationRequest: + type: object + additionalProperties: false + properties: + file: + description: | + The audio file object (not file name) translate, in one of these formats: flac, mp3, mp4, mpeg, mpga, m4a, ogg, wav, or webm. + type: string + x-oaiTypeLabel: file + format: binary + model: + description: | + ID of the model to use. Only `whisper-1` (which is powered by our open source Whisper V2 model) is currently available. + example: whisper-1 + anyOf: + - type: string + - type: string + enum: ["whisper-1"] + x-oaiTypeLabel: string + prompt: + description: | + An optional text to guide the model's style or continue a previous audio segment. The [prompt](/docs/guides/speech-to-text/prompting) should be in English. + type: string + response_format: + description: | + The format of the transcript output, in one of these options: `json`, `text`, `srt`, `verbose_json`, or `vtt`. + type: string + default: json + temperature: + description: | + The sampling temperature, between 0 and 1. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. If set to 0, the model will use [log probability](https://en.wikipedia.org/wiki/Log_probability) to automatically increase the temperature until certain thresholds are hit. + type: number + default: 0 + required: + - file + - model + + # Note: This does not currently support the non-default response format types. + CreateTranslationResponseJson: + type: object + properties: + text: + type: string + required: + - text + + CreateTranslationResponseVerboseJson: + type: object + properties: + language: + type: string + description: The language of the output translation (always `english`). + duration: + type: string + description: The duration of the input audio. + text: + type: string + description: The translated text. + segments: + type: array + description: Segments of the translated text and their corresponding details. + items: + $ref: "#/components/schemas/TranscriptionSegment" + required: [language, duration, text] + + CreateSpeechRequest: + type: object + additionalProperties: false + properties: + model: + description: | + One of the available [TTS models](/docs/models/tts): `tts-1` or `tts-1-hd` + anyOf: + - type: string + - type: string + enum: ["tts-1", "tts-1-hd"] + x-oaiTypeLabel: string + input: + type: string + description: The text to generate audio for. The maximum length is 4096 characters. + maxLength: 4096 + voice: + description: The voice to use when generating the audio. Supported voices are `alloy`, `echo`, `fable`, `onyx`, `nova`, and `shimmer`. Previews of the voices are available in the [Text to speech guide](/docs/guides/text-to-speech/voice-options). + type: string + enum: ["alloy", "echo", "fable", "onyx", "nova", "shimmer"] + response_format: + description: "The format to audio in. Supported formats are `mp3`, `opus`, `aac`, `flac`, `wav`, and `pcm`." + default: "mp3" + type: string + enum: ["mp3", "opus", "aac", "flac", "wav", "pcm"] + speed: + description: "The speed of the generated audio. Select a value from `0.25` to `4.0`. `1.0` is the default." + type: number + default: 1.0 + minimum: 0.25 + maximum: 4.0 + required: + - model + - input + - voice + + Model: + title: Model + description: Describes an OpenAI model offering that can be used with the API. + properties: + id: + type: string + description: The model identifier, which can be referenced in the API endpoints. + created: + type: integer + description: The Unix timestamp (in seconds) when the model was created. + object: + type: string + description: The object type, which is always "model". + enum: [model] + owned_by: + type: string + description: The organization that owns the model. + required: + - id + - object + - created + - owned_by + x-oaiMeta: + name: The model object + example: *retrieve_model_response + + OpenAIFile: + title: OpenAIFile + description: The `File` object represents a document that has been uploaded to OpenAI. + properties: + id: + type: string + description: The file identifier, which can be referenced in the API endpoints. + bytes: + type: integer + description: The size of the file, in bytes. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the file was created. + filename: + type: string + description: The name of the file. + object: + type: string + description: The object type, which is always `file`. + enum: ["file"] + purpose: + type: string + description: The intended purpose of the file. Supported values are `assistants`, `assistants_output`, `batch`, `batch_output`, `fine-tune`, `fine-tune-results` and `vision`. + enum: + [ + "assistants", + "assistants_output", + "batch", + "batch_output", + "fine-tune", + "fine-tune-results", + "vision", + ] + status: + type: string + deprecated: true + description: Deprecated. The current status of the file, which can be either `uploaded`, `processed`, or `error`. + enum: ["uploaded", "processed", "error"] + status_details: + type: string + deprecated: true + description: Deprecated. For details on why a fine-tuning training file failed validation, see the `error` field on `fine_tuning.job`. + required: + - id + - object + - bytes + - created_at + - filename + - purpose + - status + x-oaiMeta: + name: The file object + example: | + { + "id": "file-abc123", + "object": "file", + "bytes": 120000, + "created_at": 1677610602, + "filename": "salesOverview.pdf", + "purpose": "assistants", + } + Upload: + type: object + title: Upload + description: | + The Upload object can accept byte chunks in the form of Parts. + properties: + id: + type: string + description: The Upload unique identifier, which can be referenced in API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the Upload was created. + filename: + type: string + description: The name of the file to be uploaded. + bytes: + type: integer + description: The intended number of bytes to be uploaded. + purpose: + type: string + description: The intended purpose of the file. [Please refer here](/docs/api-reference/files/object#files/object-purpose) for acceptable values. + status: + type: string + description: The status of the Upload. + enum: ["pending", "completed", "cancelled", "expired"] + expires_at: + type: integer + description: The Unix timestamp (in seconds) for when the Upload was created. + object: + type: string + description: The object type, which is always "upload". + enum: [upload] + file: + $ref: "#/components/schemas/OpenAIFile" + nullable: true + description: The ready File object after the Upload is completed. + required: + - bytes + - created_at + - expires_at + - filename + - id + - purpose + - status + - step_number + x-oaiMeta: + name: The upload object + example: | + { + "id": "upload_abc123", + "object": "upload", + "bytes": 2147483648, + "created_at": 1719184911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + "status": "completed", + "expires_at": 1719127296, + "file": { + "id": "file-xyz321", + "object": "file", + "bytes": 2147483648, + "created_at": 1719186911, + "filename": "training_examples.jsonl", + "purpose": "fine-tune", + } + } + UploadPart: + type: object + title: UploadPart + description: | + The upload Part represents a chunk of bytes we can add to an Upload object. + properties: + id: + type: string + description: The upload Part unique identifier, which can be referenced in API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the Part was created. + upload_id: + type: string + description: The ID of the Upload object that this Part was added to. + object: + type: string + description: The object type, which is always `upload.part`. + enum: ["upload.part"] + required: + - created_at + - id + - object + - upload_id + x-oaiMeta: + name: The upload part object + example: | + { + "id": "part_def456", + "object": "upload.part", + "created_at": 1719186911, + "upload_id": "upload_abc123" + } + Embedding: + type: object + description: | + Represents an embedding vector returned by embedding endpoint. + properties: + index: + type: integer + description: The index of the embedding in the list of embeddings. + embedding: + type: array + description: | + The embedding vector, which is a list of floats. The length of vector depends on the model as listed in the [embedding guide](/docs/guides/embeddings). + items: + type: number + object: + type: string + description: The object type, which is always "embedding". + enum: [embedding] + required: + - index + - object + - embedding + x-oaiMeta: + name: The embedding object + example: | + { + "object": "embedding", + "embedding": [ + 0.0023064255, + -0.009327292, + .... (1536 floats total for ada-002) + -0.0028842222, + ], + "index": 0 + } + + FineTuningJob: + type: object + title: FineTuningJob + description: | + The `fine_tuning.job` object represents a fine-tuning job that has been created through the API. + properties: + id: + type: string + description: The object identifier, which can be referenced in the API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the fine-tuning job was created. + error: + type: object + nullable: true + description: For fine-tuning jobs that have `failed`, this will contain more information on the cause of the failure. + properties: + code: + type: string + description: A machine-readable error code. + message: + type: string + description: A human-readable error message. + param: + type: string + description: The parameter that was invalid, usually `training_file` or `validation_file`. This field will be null if the failure was not parameter-specific. + nullable: true + required: + - code + - message + - param + fine_tuned_model: + type: string + nullable: true + description: The name of the fine-tuned model that is being created. The value will be null if the fine-tuning job is still running. + finished_at: + type: integer + nullable: true + description: The Unix timestamp (in seconds) for when the fine-tuning job was finished. The value will be null if the fine-tuning job is still running. + hyperparameters: + type: object + description: The hyperparameters used for the fine-tuning job. See the [fine-tuning guide](/docs/guides/fine-tuning) for more details. + properties: + n_epochs: + oneOf: + - type: string + enum: [auto] + - type: integer + minimum: 1 + maximum: 50 + default: auto + description: + The number of epochs to train the model for. An epoch refers to one full cycle through the training dataset. + + "auto" decides the optimal number of epochs based on the size of the dataset. If setting the number manually, we support any number between 1 and 50 epochs. + required: + - n_epochs + model: + type: string + description: The base model that is being fine-tuned. + object: + type: string + description: The object type, which is always "fine_tuning.job". + enum: [fine_tuning.job] + organization_id: + type: string + description: The organization that owns the fine-tuning job. + result_files: + type: array + description: The compiled results file ID(s) for the fine-tuning job. You can retrieve the results with the [Files API](/docs/api-reference/files/retrieve-contents). + items: + type: string + example: file-abc123 + status: + type: string + description: The current status of the fine-tuning job, which can be either `validating_files`, `queued`, `running`, `succeeded`, `failed`, or `cancelled`. + enum: + [ + "validating_files", + "queued", + "running", + "succeeded", + "failed", + "cancelled", + ] + trained_tokens: + type: integer + nullable: true + description: The total number of billable tokens processed by this fine-tuning job. The value will be null if the fine-tuning job is still running. + training_file: + type: string + description: The file ID used for training. You can retrieve the training data with the [Files API](/docs/api-reference/files/retrieve-contents). + validation_file: + type: string + nullable: true + description: The file ID used for validation. You can retrieve the validation results with the [Files API](/docs/api-reference/files/retrieve-contents). + integrations: + type: array + nullable: true + description: A list of integrations to enable for this fine-tuning job. + maxItems: 5 + items: + oneOf: + - $ref: "#/components/schemas/FineTuningIntegration" + x-oaiExpandable: true + seed: + type: integer + description: The seed used for the fine-tuning job. + estimated_finish: + type: integer + nullable: true + description: The Unix timestamp (in seconds) for when the fine-tuning job is estimated to finish. The value will be null if the fine-tuning job is not running. + required: + - created_at + - error + - finished_at + - fine_tuned_model + - hyperparameters + - id + - model + - object + - organization_id + - result_files + - status + - trained_tokens + - training_file + - validation_file + - seed + x-oaiMeta: + name: The fine-tuning job object + example: *fine_tuning_example + + FineTuningIntegration: + type: object + title: Fine-Tuning Job Integration + required: + - type + - wandb + properties: + type: + type: string + description: "The type of the integration being enabled for the fine-tuning job" + enum: ["wandb"] + wandb: + type: object + description: | + The settings for your integration with Weights and Biases. This payload specifies the project that + metrics will be sent to. Optionally, you can set an explicit display name for your run, add tags + to your run, and set a default entity (team, username, etc) to be associated with your run. + required: + - project + properties: + project: + description: | + The name of the project that the new run will be created under. + type: string + example: "my-wandb-project" + name: + description: | + A display name to set for the run. If not set, we will use the Job ID as the name. + nullable: true + type: string + entity: + description: | + The entity to use for the run. This allows you to set the team or username of the WandB user that you would + like associated with the run. If not set, the default entity for the registered WandB API key is used. + nullable: true + type: string + tags: + description: | + A list of tags to be attached to the newly created run. These tags are passed through directly to WandB. Some + default tags are generated by OpenAI: "openai/finetune", "openai/{base-model}", "openai/{ftjob-abcdef}". + type: array + items: + type: string + example: "custom-tag" + + FineTuningJobEvent: + type: object + description: Fine-tuning job event object + properties: + id: + type: string + created_at: + type: integer + level: + type: string + enum: ["info", "warn", "error"] + message: + type: string + object: + type: string + enum: [fine_tuning.job.event] + required: + - id + - object + - created_at + - level + - message + x-oaiMeta: + name: The fine-tuning job event object + example: | + { + "object": "fine_tuning.job.event", + "id": "ftevent-abc123" + "created_at": 1677610602, + "level": "info", + "message": "Created fine-tuning job" + } + + FineTuningJobCheckpoint: + type: object + title: FineTuningJobCheckpoint + description: | + The `fine_tuning.job.checkpoint` object represents a model checkpoint for a fine-tuning job that is ready to use. + properties: + id: + type: string + description: The checkpoint identifier, which can be referenced in the API endpoints. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the checkpoint was created. + fine_tuned_model_checkpoint: + type: string + description: The name of the fine-tuned checkpoint model that is created. + step_number: + type: integer + description: The step number that the checkpoint was created at. + metrics: + type: object + description: Metrics at the step number during the fine-tuning job. + properties: + step: + type: number + train_loss: + type: number + train_mean_token_accuracy: + type: number + valid_loss: + type: number + valid_mean_token_accuracy: + type: number + full_valid_loss: + type: number + full_valid_mean_token_accuracy: + type: number + fine_tuning_job_id: + type: string + description: The name of the fine-tuning job that this checkpoint was created from. + object: + type: string + description: The object type, which is always "fine_tuning.job.checkpoint". + enum: [fine_tuning.job.checkpoint] + required: + - created_at + - fine_tuning_job_id + - fine_tuned_model_checkpoint + - id + - metrics + - object + - step_number + x-oaiMeta: + name: The fine-tuning job checkpoint object + example: | + { + "object": "fine_tuning.job.checkpoint", + "id": "ftckpt_qtZ5Gyk4BLq1SfLFWp3RtO3P", + "created_at": 1712211699, + "fine_tuned_model_checkpoint": "ft:gpt-4o-mini-2024-07-18:my-org:custom_suffix:9ABel2dg:ckpt-step-88", + "fine_tuning_job_id": "ftjob-fpbNQ3H1GrMehXRf8cO97xTN", + "metrics": { + "step": 88, + "train_loss": 0.478, + "train_mean_token_accuracy": 0.924, + "valid_loss": 10.112, + "valid_mean_token_accuracy": 0.145, + "full_valid_loss": 0.567, + "full_valid_mean_token_accuracy": 0.944 + }, + "step_number": 88 + } + + FinetuneChatRequestInput: + type: object + description: The per-line training example of a fine-tuning input file for chat models + properties: + messages: + type: array + minItems: 1 + items: + oneOf: + - $ref: "#/components/schemas/ChatCompletionRequestSystemMessage" + - $ref: "#/components/schemas/ChatCompletionRequestUserMessage" + - $ref: "#/components/schemas/FineTuneChatCompletionRequestAssistantMessage" + - $ref: "#/components/schemas/ChatCompletionRequestToolMessage" + - $ref: "#/components/schemas/ChatCompletionRequestFunctionMessage" + x-oaiExpandable: true + tools: + type: array + description: A list of tools the model may generate JSON inputs for. + items: + $ref: "#/components/schemas/ChatCompletionTool" + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + functions: + deprecated: true + description: A list of functions the model may generate JSON inputs for. + type: array + minItems: 1 + maxItems: 128 + items: + $ref: "#/components/schemas/ChatCompletionFunctions" + x-oaiMeta: + name: Training format for chat models + example: | + { + "messages": [ + { "role": "user", "content": "What is the weather in San Francisco?" }, + { + "role": "assistant", + "tool_calls": [ + { + "id": "call_id", + "type": "function", + "function": { + "name": "get_current_weather", + "arguments": "{\"location\": \"San Francisco, USA\", \"format\": \"celsius\"}" + } + } + ] + } + ], + "parallel_tool_calls": false, + "tools": [ + { + "type": "function", + "function": { + "name": "get_current_weather", + "description": "Get the current weather", + "parameters": { + "type": "object", + "properties": { + "location": { + "type": "string", + "description": "The city and country, eg. San Francisco, USA" + }, + "format": { "type": "string", "enum": ["celsius", "fahrenheit"] } + }, + "required": ["location", "format"] + } + } + } + ] + } + + FinetuneCompletionRequestInput: + type: object + description: The per-line training example of a fine-tuning input file for completions models + properties: + prompt: + type: string + description: The input prompt for this training example. + completion: + type: string + description: The desired completion for this training example. + x-oaiMeta: + name: Training format for completions models + example: | + { + "prompt": "What is the answer to 2+2", + "completion": "4" + } + + CompletionUsage: + type: object + description: Usage statistics for the completion request. + properties: + completion_tokens: + type: integer + description: Number of tokens in the generated completion. + prompt_tokens: + type: integer + description: Number of tokens in the prompt. + total_tokens: + type: integer + description: Total number of tokens used in the request (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + + RunCompletionUsage: + type: object + description: Usage statistics related to the run. This value will be `null` if the run is not in a terminal state (i.e. `in_progress`, `queued`, etc.). + properties: + completion_tokens: + type: integer + description: Number of completion tokens used over the course of the run. + prompt_tokens: + type: integer + description: Number of prompt tokens used over the course of the run. + total_tokens: + type: integer + description: Total number of tokens used (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + nullable: true + + RunStepCompletionUsage: + type: object + description: Usage statistics related to the run step. This value will be `null` while the run step's status is `in_progress`. + properties: + completion_tokens: + type: integer + description: Number of completion tokens used over the course of the run step. + prompt_tokens: + type: integer + description: Number of prompt tokens used over the course of the run step. + total_tokens: + type: integer + description: Total number of tokens used (prompt + completion). + required: + - prompt_tokens + - completion_tokens + - total_tokens + nullable: true + + AssistantsApiResponseFormatOption: + description: | + Specifies the format that the model must output. Compatible with [GPT-4o](/docs/models/gpt-4o), [GPT-4 Turbo](/docs/models/gpt-4-turbo-and-gpt-4), and all GPT-3.5 Turbo models since `gpt-3.5-turbo-1106`. + + Setting to `{ "type": "json_schema", "json_schema": {...} }` enables Structured Outputs which guarantees the model will match your supplied JSON schema. Learn more in the [Structured Outputs guide](/docs/guides/structured-outputs). + + Setting to `{ "type": "json_object" }` enables JSON mode, which guarantees the message the model generates is valid JSON. + + **Important:** when using JSON mode, you **must** also instruct the model to produce JSON yourself via a system or user message. Without this, the model may generate an unending stream of whitespace until the generation reaches the token limit, resulting in a long-running and seemingly "stuck" request. Also note that the message content may be partially cut off if `finish_reason="length"`, which indicates the generation exceeded `max_tokens` or the conversation exceeded the max context length. + oneOf: + - type: string + description: > + `auto` is the default value + enum: [auto] + - $ref: "#/components/schemas/ResponseFormatText" + - $ref: "#/components/schemas/ResponseFormatJsonObject" + - $ref: "#/components/schemas/ResponseFormatJsonSchema" + x-oaiExpandable: true + + AssistantObject: + type: object + title: Assistant + description: Represents an `assistant` that can call the model and use tools. + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `assistant`. + type: string + enum: [assistant] + created_at: + description: The Unix timestamp (in seconds) for when the assistant was created. + type: integer + name: + description: &assistant_name_param_description | + The name of the assistant. The maximum length is 256 characters. + type: string + maxLength: 256 + nullable: true + description: + description: &assistant_description_param_description | + The description of the assistant. The maximum length is 512 characters. + type: string + maxLength: 512 + nullable: true + model: + description: *model_description + type: string + instructions: + description: &assistant_instructions_param_description | + The system instructions that the assistant uses. The maximum length is 256,000 characters. + type: string + maxLength: 256000 + nullable: true + tools: + description: &assistant_tools_param_description | + A list of tool enabled on the assistant. There can be a maximum of 128 tools per assistant. Tools can be of types `code_interpreter`, `file_search`, or `function`. + default: [] + type: array + maxItems: 128 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter`` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: &metadata_description | + Set of 16 key-value pairs that can be attached to an object. This can be useful for storing additional information about the object in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + description: &run_temperature_description | + What sampling temperature to use, between 0 and 2. Higher values like 0.8 will make the output more random, while lower values like 0.2 will make it more focused and deterministic. + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: &run_top_p_description | + An alternative to sampling with temperature, called nucleus sampling, where the model considers the results of the tokens with top_p probability mass. So 0.1 means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or temperature but not both. + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - id + - object + - created_at + - name + - description + - model + - instructions + - tools + - metadata + x-oaiMeta: + name: The assistant object + beta: true + example: *create_assistants_example + + CreateAssistantRequest: + type: object + additionalProperties: false + properties: + model: + description: *model_description + example: "gpt-4o" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + name: + description: *assistant_name_param_description + type: string + nullable: true + maxLength: 256 + description: + description: *assistant_description_param_description + type: string + nullable: true + maxLength: 512 + instructions: + description: *assistant_instructions_param_description + type: string + nullable: true + maxLength: 256000 + tools: + description: *assistant_tools_param_description + default: [] + type: array + maxItems: 128 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + vector_stores: + type: array + description: | + A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store. + maxItems: 10000 + items: + type: string + chunking_strategy: + # Ideally we'd reuse the chunking strategy schema here, but it doesn't expand properly + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + oneOf: + - type: object + title: Auto Chunking Strategy + description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + additionalProperties: false + properties: + type: + type: string + description: Always `auto`. + enum: ["auto"] + required: + - type + - type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + type: object + additionalProperties: false + properties: + max_chunk_size_tokens: + type: integer + minimum: 100 + maximum: 4096 + description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + chunk_overlap_tokens: + type: integer + description: | + The number of tokens that overlap between chunks. The default value is `400`. + + Note that the overlap must not exceed half of `max_chunk_size_tokens`. + required: + - max_chunk_size_tokens + - chunk_overlap_tokens + required: + - type + - static + x-oaiExpandable: true + metadata: + type: object + description: | + Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + x-oaiTypeLabel: map + oneOf: + - required: [vector_store_ids] + - required: [vector_stores] + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + description: *run_temperature_description + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - model + + ModifyAssistantRequest: + type: object + additionalProperties: false + properties: + model: + description: *model_description + anyOf: + - type: string + name: + description: *assistant_name_param_description + type: string + nullable: true + maxLength: 256 + description: + description: *assistant_description_param_description + type: string + nullable: true + maxLength: 512 + instructions: + description: *assistant_instructions_param_description + type: string + nullable: true + maxLength: 256000 + tools: + description: *assistant_tools_param_description + default: [] + type: array + maxItems: 128 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + Overrides the list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + Overrides the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + description: *run_temperature_description + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + + DeleteAssistantResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [assistant.deleted] + required: + - id + - object + - deleted + + ListAssistantsResponse: + type: object + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/AssistantObject" + first_id: + type: string + example: "asst_abc123" + last_id: + type: string + example: "asst_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + x-oaiMeta: + name: List assistants response object + group: chat + example: *list_assistants_example + + AssistantToolsCode: + type: object + title: Code interpreter tool + properties: + type: + type: string + description: "The type of tool being defined: `code_interpreter`" + enum: ["code_interpreter"] + required: + - type + + AssistantToolsFileSearch: + type: object + title: FileSearch tool + properties: + type: + type: string + description: "The type of tool being defined: `file_search`" + enum: ["file_search"] + file_search: + type: object + description: Overrides for the file search tool. + properties: + max_num_results: + type: integer + minimum: 1 + maximum: 50 + description: | + The maximum number of results the file search tool should output. The default is 20 for `gpt-4*` models and 5 for `gpt-3.5-turbo`. This number should be between 1 and 50 inclusive. + + Note that the file search tool may output fewer than `max_num_results` results. See the [file search tool documentation](/docs/assistants/tools/file-search/number-of-chunks-returned) for more information. + required: + - type + + AssistantToolsFileSearchTypeOnly: + type: object + title: FileSearch tool + properties: + type: + type: string + description: "The type of tool being defined: `file_search`" + enum: ["file_search"] + required: + - type + + AssistantToolsFunction: + type: object + title: Function tool + properties: + type: + type: string + description: "The type of tool being defined: `function`" + enum: ["function"] + function: + $ref: "#/components/schemas/FunctionObject" + required: + - type + - function + + TruncationObject: + type: object + title: Thread Truncation Controls + description: Controls for how a thread will be truncated prior to the run. Use this to control the intial context window of the run. + properties: + type: + type: string + description: The truncation strategy to use for the thread. The default is `auto`. If set to `last_messages`, the thread will be truncated to the n most recent messages in the thread. When set to `auto`, messages in the middle of the thread will be dropped to fit the context length of the model, `max_prompt_tokens`. + enum: ["auto", "last_messages"] + last_messages: + type: integer + description: The number of most recent messages from the thread when constructing the context for the run. + minimum: 1 + nullable: true + required: + - type + + AssistantsApiToolChoiceOption: + description: | + Controls which (if any) tool is called by the model. + `none` means the model will not call any tools and instead generates a message. + `auto` is the default value and means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools before responding to the user. + Specifying a particular tool like `{"type": "file_search"}` or `{"type": "function", "function": {"name": "my_function"}}` forces the model to call that tool. + + oneOf: + - type: string + description: > + `none` means the model will not call any tools and instead generates a message. + `auto` means the model can pick between generating a message or calling one or more tools. + `required` means the model must call one or more tools before responding to the user. + enum: [none, auto, required] + - $ref: "#/components/schemas/AssistantsNamedToolChoice" + x-oaiExpandable: true + + AssistantsNamedToolChoice: + type: object + description: Specifies a tool the model should use. Use to force the model to call a specific tool. + properties: + type: + type: string + enum: ["function", "code_interpreter", "file_search"] + description: The type of the tool. If type is `function`, the function name must be set + function: + type: object + properties: + name: + type: string + description: The name of the function to call. + required: + - name + required: + - type + + RunObject: + type: object + title: A run on a thread + description: Represents an execution run on a [thread](/docs/api-reference/threads). + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.run`. + type: string + enum: ["thread.run"] + created_at: + description: The Unix timestamp (in seconds) for when the run was created. + type: integer + thread_id: + description: The ID of the [thread](/docs/api-reference/threads) that was executed on as a part of this run. + type: string + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) used for execution of this run. + type: string + status: + description: The status of the run, which can be either `queued`, `in_progress`, `requires_action`, `cancelling`, `cancelled`, `failed`, `completed`, `incomplete`, or `expired`. + type: string + enum: + [ + "queued", + "in_progress", + "requires_action", + "cancelling", + "cancelled", + "failed", + "completed", + "incomplete", + "expired", + ] + required_action: + type: object + description: Details on the action required to continue the run. Will be `null` if no action is required. + nullable: true + properties: + type: + description: For now, this is always `submit_tool_outputs`. + type: string + enum: ["submit_tool_outputs"] + submit_tool_outputs: + type: object + description: Details on the tool outputs needed for this run to continue. + properties: + tool_calls: + type: array + description: A list of the relevant tool calls. + items: + $ref: "#/components/schemas/RunToolCallObject" + required: + - tool_calls + required: + - type + - submit_tool_outputs + last_error: + type: object + description: The last error associated with this run. Will be `null` if there are no errors. + nullable: true + properties: + code: + type: string + description: One of `server_error`, `rate_limit_exceeded`, or `invalid_prompt`. + enum: ["server_error", "rate_limit_exceeded", "invalid_prompt"] + message: + type: string + description: A human-readable description of the error. + required: + - code + - message + expires_at: + description: The Unix timestamp (in seconds) for when the run will expire. + type: integer + nullable: true + started_at: + description: The Unix timestamp (in seconds) for when the run was started. + type: integer + nullable: true + cancelled_at: + description: The Unix timestamp (in seconds) for when the run was cancelled. + type: integer + nullable: true + failed_at: + description: The Unix timestamp (in seconds) for when the run failed. + type: integer + nullable: true + completed_at: + description: The Unix timestamp (in seconds) for when the run was completed. + type: integer + nullable: true + incomplete_details: + description: Details on why the run is incomplete. Will be `null` if the run is not incomplete. + type: object + nullable: true + properties: + reason: + description: The reason why the run is incomplete. This will point to which specific token limit was reached over the course of the run. + type: string + enum: ["max_completion_tokens", "max_prompt_tokens"] + model: + description: The model that the [assistant](/docs/api-reference/assistants) used for this run. + type: string + instructions: + description: The instructions that the [assistant](/docs/api-reference/assistants) used for this run. + type: string + tools: + description: The list of tools that the [assistant](/docs/api-reference/assistants) used for this run. + default: [] + type: array + maxItems: 20 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + usage: + $ref: "#/components/schemas/RunCompletionUsage" + temperature: + description: The sampling temperature used for this run. If not set, defaults to 1. + type: number + nullable: true + top_p: + description: The nucleus sampling value used for this run. If not set, defaults to 1. + type: number + nullable: true + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens specified to have been used over the course of the run. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens specified to have been used over the course of the run. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - id + - object + - created_at + - thread_id + - assistant_id + - status + - required_action + - last_error + - expires_at + - started_at + - cancelled_at + - failed_at + - completed_at + - model + - instructions + - tools + - metadata + - usage + - incomplete_details + - max_prompt_tokens + - max_completion_tokens + - truncation_strategy + - tool_choice + - parallel_tool_calls + - response_format + x-oaiMeta: + name: The run object + beta: true + example: | + { + "id": "run_abc123", + "object": "thread.run", + "created_at": 1698107661, + "assistant_id": "asst_abc123", + "thread_id": "thread_abc123", + "status": "completed", + "started_at": 1699073476, + "expires_at": null, + "cancelled_at": null, + "failed_at": null, + "completed_at": 1699073498, + "last_error": null, + "model": "gpt-4o", + "instructions": null, + "tools": [{"type": "file_search"}, {"type": "code_interpreter"}], + "metadata": {}, + "incomplete_details": null, + "usage": { + "prompt_tokens": 123, + "completion_tokens": 456, + "total_tokens": 579 + }, + "temperature": 1.0, + "top_p": 1.0, + "max_prompt_tokens": 1000, + "max_completion_tokens": 1000, + "truncation_strategy": { + "type": "auto", + "last_messages": null + }, + "response_format": "auto", + "tool_choice": "auto", + "parallel_tool_calls": true + } + CreateRunRequest: + type: object + additionalProperties: false + properties: + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run. + type: string + model: + description: The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. + example: "gpt-4o" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + nullable: true + instructions: + description: Overrides the [instructions](/docs/api-reference/assistants/createAssistant) of the assistant. This is useful for modifying the behavior on a per-run basis. + type: string + nullable: true + additional_instructions: + description: Appends additional instructions at the end of the instructions for the run. This is useful for modifying the behavior on a per-run basis without overriding other instructions. + type: string + nullable: true + additional_messages: + description: Adds additional messages to the thread before creating the run. + type: array + items: + $ref: "#/components/schemas/CreateMessageRequest" + nullable: true + tools: + description: Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. + nullable: true + type: array + maxItems: 20 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: *run_temperature_description + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - thread_id + - assistant_id + ListRunsResponse: + type: object + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/RunObject" + first_id: + type: string + example: "run_abc123" + last_id: + type: string + example: "run_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + ModifyRunRequest: + type: object + additionalProperties: false + properties: + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + SubmitToolOutputsRunRequest: + type: object + additionalProperties: false + properties: + tool_outputs: + description: A list of tools for which the outputs are being submitted. + type: array + items: + type: object + properties: + tool_call_id: + type: string + description: The ID of the tool call in the `required_action` object within the run object the output is being submitted for. + output: + type: string + description: The output of the tool call to be submitted to continue the run. + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + required: + - tool_outputs + + RunToolCallObject: + type: object + description: Tool call objects + properties: + id: + type: string + description: The ID of the tool call. This ID must be referenced when you submit the tool outputs in using the [Submit tool outputs to run](/docs/api-reference/runs/submitToolOutputs) endpoint. + type: + type: string + description: The type of tool call the output is required for. For now, this is always `function`. + enum: ["function"] + function: + type: object + description: The function definition. + properties: + name: + type: string + description: The name of the function. + arguments: + type: string + description: The arguments that the model expects you to pass to the function. + required: + - name + - arguments + required: + - id + - type + - function + + CreateThreadAndRunRequest: + type: object + additionalProperties: false + properties: + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) to use to execute this run. + type: string + thread: + $ref: "#/components/schemas/CreateThreadRequest" + description: If no thread is provided, an empty thread will be created. + model: + description: The ID of the [Model](/docs/api-reference/models) to be used to execute this run. If a value is provided here, it will override the model associated with the assistant. If not, the model associated with the assistant will be used. + example: "gpt-4o" + anyOf: + - type: string + - type: string + enum: + [ + "gpt-4o", + "gpt-4o-2024-08-06", + "gpt-4o-2024-05-13", + "gpt-4o-2024-08-06", + "gpt-4o-mini", + "gpt-4o-mini-2024-07-18", + "gpt-4-turbo", + "gpt-4-turbo-2024-04-09", + "gpt-4-0125-preview", + "gpt-4-turbo-preview", + "gpt-4-1106-preview", + "gpt-4-vision-preview", + "gpt-4", + "gpt-4-0314", + "gpt-4-0613", + "gpt-4-32k", + "gpt-4-32k-0314", + "gpt-4-32k-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k", + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo-1106", + "gpt-3.5-turbo-0125", + "gpt-3.5-turbo-16k-0613", + ] + x-oaiTypeLabel: string + nullable: true + instructions: + description: Override the default system message of the assistant. This is useful for modifying the behavior on a per-run basis. + type: string + nullable: true + tools: + description: Override the tools the assistant can use for this run. This is useful for modifying the behavior on a per-run basis. + nullable: true + type: array + maxItems: 20 + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearch" + - $ref: "#/components/schemas/AssistantToolsFunction" + tool_resources: + type: object + description: | + A set of resources that are used by the assistant's tools. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The ID of the [vector store](/docs/api-reference/vector-stores/object) attached to this assistant. There can be a maximum of 1 vector store attached to the assistant. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + temperature: + type: number + minimum: 0 + maximum: 2 + default: 1 + example: 1 + nullable: true + description: *run_temperature_description + top_p: + type: number + minimum: 0 + maximum: 1 + default: 1 + example: 1 + nullable: true + description: *run_top_p_description + stream: + type: boolean + nullable: true + description: | + If `true`, returns a stream of events that happen during the Run as server-sent events, terminating when the Run enters a terminal state with a `data: [DONE]` message. + max_prompt_tokens: + type: integer + nullable: true + description: | + The maximum number of prompt tokens that may be used over the course of the run. The run will make a best effort to use only the number of prompt tokens specified, across multiple turns of the run. If the run exceeds the number of prompt tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + max_completion_tokens: + type: integer + nullable: true + description: | + The maximum number of completion tokens that may be used over the course of the run. The run will make a best effort to use only the number of completion tokens specified, across multiple turns of the run. If the run exceeds the number of completion tokens specified, the run will end with status `incomplete`. See `incomplete_details` for more info. + minimum: 256 + truncation_strategy: + $ref: "#/components/schemas/TruncationObject" + nullable: true + tool_choice: + $ref: "#/components/schemas/AssistantsApiToolChoiceOption" + nullable: true + parallel_tool_calls: + $ref: "#/components/schemas/ParallelToolCalls" + response_format: + $ref: "#/components/schemas/AssistantsApiResponseFormatOption" + nullable: true + required: + - thread_id + - assistant_id + + ThreadObject: + type: object + title: Thread + description: Represents a thread that contains [messages](/docs/api-reference/messages). + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread`. + type: string + enum: ["thread"] + created_at: + description: The Unix timestamp (in seconds) for when the thread was created. + type: integer + tool_resources: + type: object + description: | + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - created_at + - tool_resources + - metadata + x-oaiMeta: + name: The thread object + beta: true + example: | + { + "id": "thread_abc123", + "object": "thread", + "created_at": 1698107661, + "metadata": {} + } + + CreateThreadRequest: + type: object + additionalProperties: false + properties: + messages: + description: A list of [messages](/docs/api-reference/messages) to start the thread with. + type: array + items: + $ref: "#/components/schemas/CreateMessageRequest" + tool_resources: + type: object + description: | + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: string + vector_stores: + type: array + description: | + A helper to create a [vector store](/docs/api-reference/vector-stores/object) with file_ids and attach it to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs to add to the vector store. There can be a maximum of 10000 files in a vector store. + maxItems: 10000 + items: + type: string + chunking_strategy: + # Ideally we'd reuse the chunking strategy schema here, but it doesn't expand properly + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + oneOf: + - type: object + title: Auto Chunking Strategy + description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + additionalProperties: false + properties: + type: + type: string + description: Always `auto`. + enum: ["auto"] + required: + - type + - type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + type: object + additionalProperties: false + properties: + max_chunk_size_tokens: + type: integer + minimum: 100 + maximum: 4096 + description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + chunk_overlap_tokens: + type: integer + description: | + The number of tokens that overlap between chunks. The default value is `400`. + + Note that the overlap must not exceed half of `max_chunk_size_tokens`. + required: + - max_chunk_size_tokens + - chunk_overlap_tokens + required: + - type + - static + x-oaiExpandable: true + metadata: + type: object + description: | + Set of 16 key-value pairs that can be attached to a vector store. This can be useful for storing additional information about the vector store in a structured format. Keys can be a maximum of 64 characters long and values can be a maxium of 512 characters long. + x-oaiTypeLabel: map + x-oaiExpandable: true + oneOf: + - required: [vector_store_ids] + - required: [vector_stores] + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + ModifyThreadRequest: + type: object + additionalProperties: false + properties: + tool_resources: + type: object + description: | + A set of resources that are made available to the assistant's tools in this thread. The resources are specific to the type of tool. For example, the `code_interpreter` tool requires a list of file IDs, while the `file_search` tool requires a list of vector store IDs. + properties: + code_interpreter: + type: object + properties: + file_ids: + type: array + description: | + A list of [file](/docs/api-reference/files) IDs made available to the `code_interpreter` tool. There can be a maximum of 20 files associated with the tool. + default: [] + maxItems: 20 + items: + type: string + file_search: + type: object + properties: + vector_store_ids: + type: array + description: | + The [vector store](/docs/api-reference/vector-stores/object) attached to this thread. There can be a maximum of 1 vector store attached to the thread. + maxItems: 1 + items: + type: string + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + DeleteThreadResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [thread.deleted] + required: + - id + - object + - deleted + + ListThreadsResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/ThreadObject" + first_id: + type: string + example: "asst_abc123" + last_id: + type: string + example: "asst_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + MessageObject: + type: object + title: The message object + description: Represents a message within a [thread](/docs/api-reference/threads). + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.message`. + type: string + enum: ["thread.message"] + created_at: + description: The Unix timestamp (in seconds) for when the message was created. + type: integer + thread_id: + description: The [thread](/docs/api-reference/threads) ID that this message belongs to. + type: string + status: + description: The status of the message, which can be either `in_progress`, `incomplete`, or `completed`. + type: string + enum: ["in_progress", "incomplete", "completed"] + incomplete_details: + description: On an incomplete message, details about why the message is incomplete. + type: object + properties: + reason: + type: string + description: The reason the message is incomplete. + enum: + [ + "content_filter", + "max_tokens", + "run_cancelled", + "run_expired", + "run_failed", + ] + nullable: true + required: + - reason + completed_at: + description: The Unix timestamp (in seconds) for when the message was completed. + type: integer + nullable: true + incomplete_at: + description: The Unix timestamp (in seconds) for when the message was marked as incomplete. + type: integer + nullable: true + role: + description: The entity that produced the message. One of `user` or `assistant`. + type: string + enum: ["user", "assistant"] + content: + description: The content of the message in array of text and/or images. + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageContentImageFileObject" + - $ref: "#/components/schemas/MessageContentImageUrlObject" + - $ref: "#/components/schemas/MessageContentTextObject" + - $ref: "#/components/schemas/MessageContentRefusalObject" + x-oaiExpandable: true + assistant_id: + description: If applicable, the ID of the [assistant](/docs/api-reference/assistants) that authored this message. + type: string + nullable: true + run_id: + description: The ID of the [run](/docs/api-reference/runs) associated with the creation of this message. Value is `null` when messages are created manually using the create message or create thread endpoints. + type: string + nullable: true + attachments: + type: array + items: + type: object + properties: + file_id: + type: string + description: The ID of the file to attach to the message. + tools: + description: The tools to add this file to. + type: array + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearchTypeOnly" + x-oaiExpandable: true + description: A list of files attached to the message, and the tools they were added to. + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - created_at + - thread_id + - status + - incomplete_details + - completed_at + - incomplete_at + - role + - content + - assistant_id + - run_id + - attachments + - metadata + x-oaiMeta: + name: The message object + beta: true + example: | + { + "id": "msg_abc123", + "object": "thread.message", + "created_at": 1698983503, + "thread_id": "thread_abc123", + "role": "assistant", + "content": [ + { + "type": "text", + "text": { + "value": "Hi! How can I help you today?", + "annotations": [] + } + } + ], + "assistant_id": "asst_abc123", + "run_id": "run_abc123", + "attachments": [], + "metadata": {} + } + + MessageDeltaObject: + type: object + title: Message delta object + description: | + Represents a message delta i.e. any changed fields on a message during streaming. + properties: + id: + description: The identifier of the message, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.message.delta`. + type: string + enum: ["thread.message.delta"] + delta: + description: The delta containing the fields that have changed on the Message. + type: object + properties: + role: + description: The entity that produced the message. One of `user` or `assistant`. + type: string + enum: ["user", "assistant"] + content: + description: The content of the message in array of text and/or images. + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageDeltaContentImageFileObject" + - $ref: "#/components/schemas/MessageDeltaContentTextObject" + - $ref: "#/components/schemas/MessageDeltaContentRefusalObject" + - $ref: "#/components/schemas/MessageDeltaContentImageUrlObject" + x-oaiExpandable: true + required: + - id + - object + - delta + x-oaiMeta: + name: The message delta object + beta: true + example: | + { + "id": "msg_123", + "object": "thread.message.delta", + "delta": { + "content": [ + { + "index": 0, + "type": "text", + "text": { "value": "Hello", "annotations": [] } + } + ] + } + } + + CreateMessageRequest: + type: object + additionalProperties: false + required: + - role + - content + properties: + role: + type: string + enum: ["user", "assistant"] + description: | + The role of the entity that is creating the message. Allowed values include: + - `user`: Indicates the message is sent by an actual user and should be used in most cases to represent user-generated messages. + - `assistant`: Indicates the message is generated by the assistant. Use this value to insert messages from the assistant into the conversation. + content: + oneOf: + - type: string + description: The text contents of the message. + title: Text content + - type: array + description: An array of content parts with a defined type, each can be of type `text` or images can be passed with `image_url` or `image_file`. Image types are only supported on [Vision-compatible models](/docs/models/overview). + title: Array of content parts + items: + oneOf: + - $ref: "#/components/schemas/MessageContentImageFileObject" + - $ref: "#/components/schemas/MessageContentImageUrlObject" + - $ref: "#/components/schemas/MessageRequestContentTextObject" + x-oaiExpandable: true + minItems: 1 + x-oaiExpandable: true + attachments: + type: array + items: + type: object + properties: + file_id: + type: string + description: The ID of the file to attach to the message. + tools: + description: The tools to add this file to. + type: array + items: + oneOf: + - $ref: "#/components/schemas/AssistantToolsCode" + - $ref: "#/components/schemas/AssistantToolsFileSearchTypeOnly" + x-oaiExpandable: true + description: A list of files attached to the message, and the tools they should be added to. + required: + - file_id + - tools + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + ModifyMessageRequest: + type: object + additionalProperties: false + properties: + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + DeleteMessageResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [thread.message.deleted] + required: + - id + - object + - deleted + + ListMessagesResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/MessageObject" + first_id: + type: string + example: "msg_abc123" + last_id: + type: string + example: "msg_abc123" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + MessageContentImageFileObject: + title: Image file + type: object + description: References an image [File](/docs/api-reference/files) in the content of a message. + properties: + type: + description: Always `image_file`. + type: string + enum: ["image_file"] + image_file: + type: object + properties: + file_id: + description: The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. + type: string + detail: + type: string + description: Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" + required: + - file_id + required: + - type + - image_file + + MessageDeltaContentImageFileObject: + title: Image file + type: object + description: References an image [File](/docs/api-reference/files) in the content of a message. + properties: + index: + type: integer + description: The index of the content part in the message. + type: + description: Always `image_file`. + type: string + enum: ["image_file"] + image_file: + type: object + properties: + file_id: + description: The [File](/docs/api-reference/files) ID of the image in the message content. Set `purpose="vision"` when uploading the File if you need to later display the file content. + type: string + detail: + type: string + description: Specifies the detail level of the image if specified by the user. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" + required: + - index + - type + + MessageContentImageUrlObject: + title: Image URL + type: object + description: References an image URL in the content of a message. + properties: + type: + type: string + enum: ["image_url"] + description: The type of the content part. + image_url: + type: object + properties: + url: + type: string + description: "The external URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + format: uri + detail: + type: string + description: Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. Default value is `auto` + enum: ["auto", "low", "high"] + default: "auto" + required: + - url + required: + - type + - image_url + + MessageDeltaContentImageUrlObject: + title: Image URL + type: object + description: References an image URL in the content of a message. + properties: + index: + type: integer + description: The index of the content part in the message. + type: + description: Always `image_url`. + type: string + enum: ["image_url"] + image_url: + type: object + properties: + url: + description: "The URL of the image, must be a supported image types: jpeg, jpg, png, gif, webp." + type: string + detail: + type: string + description: Specifies the detail level of the image. `low` uses fewer tokens, you can opt in to high resolution using `high`. + enum: ["auto", "low", "high"] + default: "auto" + required: + - index + - type + + MessageContentTextObject: + title: Text + type: object + description: The text content that is part of a message. + properties: + type: + description: Always `text`. + type: string + enum: ["text"] + text: + type: object + properties: + value: + description: The data that makes up the text. + type: string + annotations: + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageContentTextAnnotationsFileCitationObject" + - $ref: "#/components/schemas/MessageContentTextAnnotationsFilePathObject" + x-oaiExpandable: true + required: + - value + - annotations + required: + - type + - text + + MessageContentRefusalObject: + title: Refusal + type: object + description: The refusal content generated by the assistant. + properties: + type: + description: Always `refusal`. + type: string + enum: ["refusal"] + refusal: + type: string + nullable: false + required: + - type + - refusal + + MessageRequestContentTextObject: + title: Text + type: object + description: The text content that is part of a message. + properties: + type: + description: Always `text`. + type: string + enum: ["text"] + text: + type: string + description: Text content to be sent to the model + required: + - type + - text + + MessageContentTextAnnotationsFileCitationObject: + title: File citation + type: object + description: A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. + properties: + type: + description: Always `file_citation`. + type: string + enum: ["file_citation"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_citation: + type: object + properties: + file_id: + description: The ID of the specific File the citation is from. + type: string + required: + - file_id + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 + required: + - type + - text + - file_citation + - start_index + - end_index + + MessageContentTextAnnotationsFilePathObject: + title: File path + type: object + description: A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. + properties: + type: + description: Always `file_path`. + type: string + enum: ["file_path"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_path: + type: object + properties: + file_id: + description: The ID of the file that was generated. + type: string + required: + - file_id + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 + required: + - type + - text + - file_path + - start_index + - end_index + + MessageDeltaContentTextObject: + title: Text + type: object + description: The text content that is part of a message. + properties: + index: + type: integer + description: The index of the content part in the message. + type: + description: Always `text`. + type: string + enum: ["text"] + text: + type: object + properties: + value: + description: The data that makes up the text. + type: string + annotations: + type: array + items: + oneOf: + - $ref: "#/components/schemas/MessageDeltaContentTextAnnotationsFileCitationObject" + - $ref: "#/components/schemas/MessageDeltaContentTextAnnotationsFilePathObject" + x-oaiExpandable: true + required: + - index + - type + + MessageDeltaContentRefusalObject: + title: Refusal + type: object + description: The refusal content that is part of a message. + properties: + index: + type: integer + description: The index of the refusal part in the message. + type: + description: Always `refusal`. + type: string + enum: ["refusal"] + refusal: + type: string + required: + - index + - type + + MessageDeltaContentTextAnnotationsFileCitationObject: + title: File citation + type: object + description: A citation within the message that points to a specific quote from a specific File associated with the assistant or the message. Generated when the assistant uses the "file_search" tool to search files. + properties: + index: + type: integer + description: The index of the annotation in the text content part. + type: + description: Always `file_citation`. + type: string + enum: ["file_citation"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_citation: + type: object + properties: + file_id: + description: The ID of the specific File the citation is from. + type: string + quote: + description: The specific quote in the file. + type: string + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 + required: + - index + - type + + MessageDeltaContentTextAnnotationsFilePathObject: + title: File path + type: object + description: A URL for the file that's generated when the assistant used the `code_interpreter` tool to generate a file. + properties: + index: + type: integer + description: The index of the annotation in the text content part. + type: + description: Always `file_path`. + type: string + enum: ["file_path"] + text: + description: The text in the message content that needs to be replaced. + type: string + file_path: + type: object + properties: + file_id: + description: The ID of the file that was generated. + type: string + start_index: + type: integer + minimum: 0 + end_index: + type: integer + minimum: 0 + required: + - index + - type + + RunStepObject: + type: object + title: Run steps + description: | + Represents a step in execution of a run. + properties: + id: + description: The identifier of the run step, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.run.step`. + type: string + enum: ["thread.run.step"] + created_at: + description: The Unix timestamp (in seconds) for when the run step was created. + type: integer + assistant_id: + description: The ID of the [assistant](/docs/api-reference/assistants) associated with the run step. + type: string + thread_id: + description: The ID of the [thread](/docs/api-reference/threads) that was run. + type: string + run_id: + description: The ID of the [run](/docs/api-reference/runs) that this run step is a part of. + type: string + type: + description: The type of run step, which can be either `message_creation` or `tool_calls`. + type: string + enum: ["message_creation", "tool_calls"] + status: + description: The status of the run step, which can be either `in_progress`, `cancelled`, `failed`, `completed`, or `expired`. + type: string + enum: ["in_progress", "cancelled", "failed", "completed", "expired"] + step_details: + type: object + description: The details of the run step. + oneOf: + - $ref: "#/components/schemas/RunStepDetailsMessageCreationObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsObject" + x-oaiExpandable: true + last_error: + type: object + description: The last error associated with this run step. Will be `null` if there are no errors. + nullable: true + properties: + code: + type: string + description: One of `server_error` or `rate_limit_exceeded`. + enum: ["server_error", "rate_limit_exceeded"] + message: + type: string + description: A human-readable description of the error. + required: + - code + - message + expired_at: + description: The Unix timestamp (in seconds) for when the run step expired. A step is considered expired if the parent run is expired. + type: integer + nullable: true + cancelled_at: + description: The Unix timestamp (in seconds) for when the run step was cancelled. + type: integer + nullable: true + failed_at: + description: The Unix timestamp (in seconds) for when the run step failed. + type: integer + nullable: true + completed_at: + description: The Unix timestamp (in seconds) for when the run step completed. + type: integer + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + usage: + $ref: "#/components/schemas/RunStepCompletionUsage" + required: + - id + - object + - created_at + - assistant_id + - thread_id + - run_id + - type + - status + - step_details + - last_error + - expired_at + - cancelled_at + - failed_at + - completed_at + - metadata + - usage + x-oaiMeta: + name: The run step object + beta: true + example: *run_step_object_example + + RunStepDeltaObject: + type: object + title: Run step delta object + description: | + Represents a run step delta i.e. any changed fields on a run step during streaming. + properties: + id: + description: The identifier of the run step, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `thread.run.step.delta`. + type: string + enum: ["thread.run.step.delta"] + delta: + description: The delta containing the fields that have changed on the run step. + type: object + properties: + step_details: + type: object + description: The details of the run step. + oneOf: + - $ref: "#/components/schemas/RunStepDeltaStepDetailsMessageCreationObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsObject" + x-oaiExpandable: true + required: + - id + - object + - delta + x-oaiMeta: + name: The run step delta object + beta: true + example: | + { + "id": "step_123", + "object": "thread.run.step.delta", + "delta": { + "step_details": { + "type": "tool_calls", + "tool_calls": [ + { + "index": 0, + "id": "call_123", + "type": "code_interpreter", + "code_interpreter": { "input": "", "outputs": [] } + } + ] + } + } + } + + ListRunStepsResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/RunStepObject" + first_id: + type: string + example: "step_abc123" + last_id: + type: string + example: "step_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + RunStepDetailsMessageCreationObject: + title: Message creation + type: object + description: Details of the message creation by the run step. + properties: + type: + description: Always `message_creation`. + type: string + enum: ["message_creation"] + message_creation: + type: object + properties: + message_id: + type: string + description: The ID of the message that was created by this run step. + required: + - message_id + required: + - type + - message_creation + + RunStepDeltaStepDetailsMessageCreationObject: + title: Message creation + type: object + description: Details of the message creation by the run step. + properties: + type: + description: Always `message_creation`. + type: string + enum: ["message_creation"] + message_creation: + type: object + properties: + message_id: + type: string + description: The ID of the message that was created by this run step. + required: + - type + + RunStepDetailsToolCallsObject: + title: Tool calls + type: object + description: Details of the tool call. + properties: + type: + description: Always `tool_calls`. + type: string + enum: ["tool_calls"] + tool_calls: + type: array + description: | + An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. + items: + oneOf: + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsFileSearchObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsFunctionObject" + x-oaiExpandable: true + required: + - type + - tool_calls + + RunStepDeltaStepDetailsToolCallsObject: + title: Tool calls + type: object + description: Details of the tool call. + properties: + type: + description: Always `tool_calls`. + type: string + enum: ["tool_calls"] + tool_calls: + type: array + description: | + An array of tool calls the run step was involved in. These can be associated with one of three types of tools: `code_interpreter`, `file_search`, or `function`. + items: + oneOf: + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsFileSearchObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsFunctionObject" + x-oaiExpandable: true + required: + - type + + RunStepDetailsToolCallsCodeObject: + title: Code Interpreter tool call + type: object + description: Details of the Code Interpreter tool call the run step was involved in. + properties: + id: + type: string + description: The ID of the tool call. + type: + type: string + description: The type of tool call. This is always going to be `code_interpreter` for this type of tool call. + enum: ["code_interpreter"] + code_interpreter: + type: object + description: The Code Interpreter tool call definition. + required: + - input + - outputs + properties: + input: + type: string + description: The input to the Code Interpreter tool call. + outputs: + type: array + description: The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. + items: + type: object + oneOf: + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeOutputLogsObject" + - $ref: "#/components/schemas/RunStepDetailsToolCallsCodeOutputImageObject" + x-oaiExpandable: true + required: + - id + - type + - code_interpreter + + RunStepDeltaStepDetailsToolCallsCodeObject: + title: Code interpreter tool call + type: object + description: Details of the Code Interpreter tool call the run step was involved in. + properties: + index: + type: integer + description: The index of the tool call in the tool calls array. + id: + type: string + description: The ID of the tool call. + type: + type: string + description: The type of tool call. This is always going to be `code_interpreter` for this type of tool call. + enum: ["code_interpreter"] + code_interpreter: + type: object + description: The Code Interpreter tool call definition. + properties: + input: + type: string + description: The input to the Code Interpreter tool call. + outputs: + type: array + description: The outputs from the Code Interpreter tool call. Code Interpreter can output one or more items, including text (`logs`) or images (`image`). Each of these are represented by a different object type. + items: + type: object + oneOf: + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject" + - $ref: "#/components/schemas/RunStepDeltaStepDetailsToolCallsCodeOutputImageObject" + x-oaiExpandable: true + required: + - index + - type + + RunStepDetailsToolCallsCodeOutputLogsObject: + title: Code Interpreter log output + type: object + description: Text output from the Code Interpreter tool call as part of a run step. + properties: + type: + description: Always `logs`. + type: string + enum: ["logs"] + logs: + type: string + description: The text output from the Code Interpreter tool call. + required: + - type + - logs + + RunStepDeltaStepDetailsToolCallsCodeOutputLogsObject: + title: Code interpreter log output + type: object + description: Text output from the Code Interpreter tool call as part of a run step. + properties: + index: + type: integer + description: The index of the output in the outputs array. + type: + description: Always `logs`. + type: string + enum: ["logs"] + logs: + type: string + description: The text output from the Code Interpreter tool call. + required: + - index + - type + + RunStepDetailsToolCallsCodeOutputImageObject: + title: Code Interpreter image output + type: object + properties: + type: + description: Always `image`. + type: string + enum: ["image"] + image: + type: object + properties: + file_id: + description: The [file](/docs/api-reference/files) ID of the image. + type: string + required: + - file_id + required: + - type + - image + + RunStepDeltaStepDetailsToolCallsCodeOutputImageObject: + title: Code interpreter image output + type: object + properties: + index: + type: integer + description: The index of the output in the outputs array. + type: + description: Always `image`. + type: string + enum: ["image"] + image: + type: object + properties: + file_id: + description: The [file](/docs/api-reference/files) ID of the image. + type: string + required: + - index + - type + + RunStepDetailsToolCallsFileSearchObject: + title: File search tool call + type: object + properties: + id: + type: string + description: The ID of the tool call object. + type: + type: string + description: The type of tool call. This is always going to be `file_search` for this type of tool call. + enum: ["file_search"] + file_search: + type: object + description: For now, this is always going to be an empty object. + x-oaiTypeLabel: map + required: + - id + - type + - file_search + + RunStepDeltaStepDetailsToolCallsFileSearchObject: + title: File search tool call + type: object + properties: + index: + type: integer + description: The index of the tool call in the tool calls array. + id: + type: string + description: The ID of the tool call object. + type: + type: string + description: The type of tool call. This is always going to be `file_search` for this type of tool call. + enum: ["file_search"] + file_search: + type: object + description: For now, this is always going to be an empty object. + x-oaiTypeLabel: map + required: + - index + - type + - file_search + + RunStepDetailsToolCallsFunctionObject: + type: object + title: Function tool call + properties: + id: + type: string + description: The ID of the tool call object. + type: + type: string + description: The type of tool call. This is always going to be `function` for this type of tool call. + enum: ["function"] + function: + type: object + description: The definition of the function that was called. + properties: + name: + type: string + description: The name of the function. + arguments: + type: string + description: The arguments passed to the function. + output: + type: string + description: The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet. + nullable: true + required: + - name + - arguments + - output + required: + - id + - type + - function + + RunStepDeltaStepDetailsToolCallsFunctionObject: + type: object + title: Function tool call + properties: + index: + type: integer + description: The index of the tool call in the tool calls array. + id: + type: string + description: The ID of the tool call object. + type: + type: string + description: The type of tool call. This is always going to be `function` for this type of tool call. + enum: ["function"] + function: + type: object + description: The definition of the function that was called. + properties: + name: + type: string + description: The name of the function. + arguments: + type: string + description: The arguments passed to the function. + output: + type: string + description: The output of the function. This will be `null` if the outputs have not been [submitted](/docs/api-reference/runs/submitToolOutputs) yet. + nullable: true + required: + - index + - type + + VectorStoreExpirationAfter: + type: object + title: Vector store expiration policy + description: The expiration policy for a vector store. + properties: + anchor: + description: "Anchor timestamp after which the expiration policy applies. Supported anchors: `last_active_at`." + type: string + enum: ["last_active_at"] + days: + description: The number of days after the anchor time that the vector store will expire. + type: integer + minimum: 1 + maximum: 365 + required: + - anchor + - days + + VectorStoreObject: + type: object + title: Vector store + description: A vector store is a collection of processed files can be used by the `file_search` tool. + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `vector_store`. + type: string + enum: ["vector_store"] + created_at: + description: The Unix timestamp (in seconds) for when the vector store was created. + type: integer + name: + description: The name of the vector store. + type: string + usage_bytes: + description: The total number of bytes used by the files in the vector store. + type: integer + file_counts: + type: object + properties: + in_progress: + description: The number of files that are currently being processed. + type: integer + completed: + description: The number of files that have been successfully processed. + type: integer + failed: + description: The number of files that have failed to process. + type: integer + cancelled: + description: The number of files that were cancelled. + type: integer + total: + description: The total number of files. + type: integer + required: + - in_progress + - completed + - failed + - cancelled + - total + status: + description: The status of the vector store, which can be either `expired`, `in_progress`, or `completed`. A status of `completed` indicates that the vector store is ready for use. + type: string + enum: ["expired", "in_progress", "completed"] + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + expires_at: + description: The Unix timestamp (in seconds) for when the vector store will expire. + type: integer + nullable: true + last_active_at: + description: The Unix timestamp (in seconds) for when the vector store was last active. + type: integer + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - usage_bytes + - created_at + - status + - last_active_at + - name + - file_counts + - metadata + x-oaiMeta: + name: The vector store object + beta: true + example: | + { + "id": "vs_123", + "object": "vector_store", + "created_at": 1698107661, + "usage_bytes": 123456, + "last_active_at": 1698107661, + "name": "my_vector_store", + "status": "completed", + "file_counts": { + "in_progress": 0, + "completed": 100, + "cancelled": 0, + "failed": 0, + "total": 100 + }, + "metadata": {}, + "last_used_at": 1698107661 + } + + CreateVectorStoreRequest: + type: object + additionalProperties: false + properties: + file_ids: + description: A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. + type: array + maxItems: 500 + items: + type: string + name: + description: The name of the vector store. + type: string + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + chunking_strategy: + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. Only applicable if `file_ids` is non-empty. + oneOf: + - $ref: "#/components/schemas/AutoChunkingStrategyRequestParam" + - $ref: "#/components/schemas/StaticChunkingStrategyRequestParam" + x-oaiExpandable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + UpdateVectorStoreRequest: + type: object + additionalProperties: false + properties: + name: + description: The name of the vector store. + type: string + nullable: true + expires_after: + $ref: "#/components/schemas/VectorStoreExpirationAfter" + nullable: true + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + + ListVectorStoresResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/VectorStoreObject" + first_id: + type: string + example: "vs_abc123" + last_id: + type: string + example: "vs_abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + DeleteVectorStoreResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [vector_store.deleted] + required: + - id + - object + - deleted + + VectorStoreFileObject: + type: object + title: Vector store files + description: A list of files attached to a vector store. + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `vector_store.file`. + type: string + enum: ["vector_store.file"] + usage_bytes: + description: The total vector store usage in bytes. Note that this may be different from the original file size. + type: integer + created_at: + description: The Unix timestamp (in seconds) for when the vector store file was created. + type: integer + vector_store_id: + description: The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to. + type: string + status: + description: The status of the vector store file, which can be either `in_progress`, `completed`, `cancelled`, or `failed`. The status `completed` indicates that the vector store file is ready for use. + type: string + enum: ["in_progress", "completed", "cancelled", "failed"] + last_error: + type: object + description: The last error associated with this vector store file. Will be `null` if there are no errors. + nullable: true + properties: + code: + type: string + description: One of `server_error` or `rate_limit_exceeded`. + enum: ["server_error", "unsupported_file", "invalid_file"] + message: + type: string + description: A human-readable description of the error. + required: + - code + - message + chunking_strategy: + type: object + description: The strategy used to chunk the file. + oneOf: + - $ref: "#/components/schemas/StaticChunkingStrategyResponseParam" + - $ref: "#/components/schemas/OtherChunkingStrategyResponseParam" + x-oaiExpandable: true + required: + - id + - object + - usage_bytes + - created_at + - vector_store_id + - status + - last_error + x-oaiMeta: + name: The vector store file object + beta: true + example: | + { + "id": "file-abc123", + "object": "vector_store.file", + "usage_bytes": 1234, + "created_at": 1698107661, + "vector_store_id": "vs_abc123", + "status": "completed", + "last_error": null, + "chunking_strategy": { + "type": "static", + "static": { + "max_chunk_size_tokens": 800, + "chunk_overlap_tokens": 400 + } + } + } + + OtherChunkingStrategyResponseParam: + type: object + title: Other Chunking Strategy + description: This is returned when the chunking strategy is unknown. Typically, this is because the file was indexed before the `chunking_strategy` concept was introduced in the API. + additionalProperties: false + properties: + type: + type: string + description: Always `other`. + enum: ["other"] + required: + - type + + StaticChunkingStrategyResponseParam: + type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + $ref: "#/components/schemas/StaticChunkingStrategy" + required: + - type + - static + + StaticChunkingStrategy: + type: object + additionalProperties: false + properties: + max_chunk_size_tokens: + type: integer + minimum: 100 + maximum: 4096 + description: The maximum number of tokens in each chunk. The default value is `800`. The minimum value is `100` and the maximum value is `4096`. + chunk_overlap_tokens: + type: integer + description: | + The number of tokens that overlap between chunks. The default value is `400`. + + Note that the overlap must not exceed half of `max_chunk_size_tokens`. + required: + - max_chunk_size_tokens + - chunk_overlap_tokens + + AutoChunkingStrategyRequestParam: + type: object + title: Auto Chunking Strategy + description: The default strategy. This strategy currently uses a `max_chunk_size_tokens` of `800` and `chunk_overlap_tokens` of `400`. + additionalProperties: false + properties: + type: + type: string + description: Always `auto`. + enum: ["auto"] + required: + - type + + StaticChunkingStrategyRequestParam: + type: object + title: Static Chunking Strategy + additionalProperties: false + properties: + type: + type: string + description: Always `static`. + enum: ["static"] + static: + $ref: "#/components/schemas/StaticChunkingStrategy" + required: + - type + - static + + ChunkingStrategyRequestParam: + type: object + description: The chunking strategy used to chunk the file(s). If not set, will use the `auto` strategy. + oneOf: + - $ref: "#/components/schemas/AutoChunkingStrategyRequestParam" + - $ref: "#/components/schemas/StaticChunkingStrategyRequestParam" + x-oaiExpandable: true + + CreateVectorStoreFileRequest: + type: object + additionalProperties: false + properties: + file_id: + description: A [File](/docs/api-reference/files) ID that the vector store should use. Useful for tools like `file_search` that can access files. + type: string + chunking_strategy: + $ref: "#/components/schemas/ChunkingStrategyRequestParam" + required: + - file_id + + ListVectorStoreFilesResponse: + properties: + object: + type: string + example: "list" + data: + type: array + items: + $ref: "#/components/schemas/VectorStoreFileObject" + first_id: + type: string + example: "file-abc123" + last_id: + type: string + example: "file-abc456" + has_more: + type: boolean + example: false + required: + - object + - data + - first_id + - last_id + - has_more + + DeleteVectorStoreFileResponse: + type: object + properties: + id: + type: string + deleted: + type: boolean + object: + type: string + enum: [vector_store.file.deleted] + required: + - id + - object + - deleted + + VectorStoreFileBatchObject: + type: object + title: Vector store file batch + description: A batch of files attached to a vector store. + properties: + id: + description: The identifier, which can be referenced in API endpoints. + type: string + object: + description: The object type, which is always `vector_store.file_batch`. + type: string + enum: ["vector_store.files_batch"] + created_at: + description: The Unix timestamp (in seconds) for when the vector store files batch was created. + type: integer + vector_store_id: + description: The ID of the [vector store](/docs/api-reference/vector-stores/object) that the [File](/docs/api-reference/files) is attached to. + type: string + status: + description: The status of the vector store files batch, which can be either `in_progress`, `completed`, `cancelled` or `failed`. + type: string + enum: ["in_progress", "completed", "cancelled", "failed"] + file_counts: + type: object + properties: + in_progress: + description: The number of files that are currently being processed. + type: integer + completed: + description: The number of files that have been processed. + type: integer + failed: + description: The number of files that have failed to process. + type: integer + cancelled: + description: The number of files that where cancelled. + type: integer + total: + description: The total number of files. + type: integer + required: + - in_progress + - completed + - cancelled + - failed + - total + required: + - id + - object + - created_at + - vector_store_id + - status + - file_counts + x-oaiMeta: + name: The vector store files batch object + beta: true + example: | + { + "id": "vsfb_123", + "object": "vector_store.files_batch", + "created_at": 1698107661, + "vector_store_id": "vs_abc123", + "status": "completed", + "file_counts": { + "in_progress": 0, + "completed": 100, + "failed": 0, + "cancelled": 0, + "total": 100 + } + } + + CreateVectorStoreFileBatchRequest: + type: object + additionalProperties: false + properties: + file_ids: + description: A list of [File](/docs/api-reference/files) IDs that the vector store should use. Useful for tools like `file_search` that can access files. + type: array + minItems: 1 + maxItems: 500 + items: + type: string + chunking_strategy: + $ref: "#/components/schemas/ChunkingStrategyRequestParam" + required: + - file_ids + + AssistantStreamEvent: + description: | + Represents an event emitted when streaming a Run. + + Each event in a server-sent events stream has an `event` and `data` property: + + ``` + event: thread.created + data: {"id": "thread_123", "object": "thread", ...} + ``` + + We emit events whenever a new object is created, transitions to a new state, or is being + streamed in parts (deltas). For example, we emit `thread.run.created` when a new run + is created, `thread.run.completed` when a run completes, and so on. When an Assistant chooses + to create a message during a run, we emit a `thread.message.created event`, a + `thread.message.in_progress` event, many `thread.message.delta` events, and finally a + `thread.message.completed` event. + + We may add additional events over time, so we recommend handling unknown events gracefully + in your code. See the [Assistants API quickstart](/docs/assistants/overview) to learn how to + integrate the Assistants API with streaming. + oneOf: + - $ref: "#/components/schemas/ThreadStreamEvent" + - $ref: "#/components/schemas/RunStreamEvent" + - $ref: "#/components/schemas/RunStepStreamEvent" + - $ref: "#/components/schemas/MessageStreamEvent" + - $ref: "#/components/schemas/ErrorEvent" + - $ref: "#/components/schemas/DoneEvent" + x-oaiMeta: + name: Assistant stream events + beta: true + + ThreadStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.created"] + data: + $ref: "#/components/schemas/ThreadObject" + required: + - event + - data + description: Occurs when a new [thread](/docs/api-reference/threads/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [thread](/docs/api-reference/threads/object)" + + RunStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.run.created"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a new [run](/docs/api-reference/runs/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.queued"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `queued` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.in_progress"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to an `in_progress` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.requires_action"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `requires_action` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.completed"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.incomplete"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) ends with status `incomplete`. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.failed"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) fails. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.cancelling"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) moves to a `cancelling` status. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.cancelled"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) is cancelled. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.expired"] + data: + $ref: "#/components/schemas/RunObject" + required: + - event + - data + description: Occurs when a [run](/docs/api-reference/runs/object) expires. + x-oaiMeta: + dataDescription: "`data` is a [run](/docs/api-reference/runs/object)" + + RunStepStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.run.step.created"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is created. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.in_progress"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) moves to an `in_progress` state. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.delta"] + data: + $ref: "#/components/schemas/RunStepDeltaObject" + required: + - event + - data + description: Occurs when parts of a [run step](/docs/api-reference/runs/step-object) are being streamed. + x-oaiMeta: + dataDescription: "`data` is a [run step delta](/docs/api-reference/assistants-streaming/run-step-delta-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.completed"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.failed"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) fails. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.cancelled"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) is cancelled. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + - type: object + properties: + event: + type: string + enum: ["thread.run.step.expired"] + data: + $ref: "#/components/schemas/RunStepObject" + required: + - event + - data + description: Occurs when a [run step](/docs/api-reference/runs/step-object) expires. + x-oaiMeta: + dataDescription: "`data` is a [run step](/docs/api-reference/runs/step-object)" + + MessageStreamEvent: + oneOf: + - type: object + properties: + event: + type: string + enum: ["thread.message.created"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) is created. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.in_progress"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) moves to an `in_progress` state. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.delta"] + data: + $ref: "#/components/schemas/MessageDeltaObject" + required: + - event + - data + description: Occurs when parts of a [Message](/docs/api-reference/messages/object) are being streamed. + x-oaiMeta: + dataDescription: "`data` is a [message delta](/docs/api-reference/assistants-streaming/message-delta-object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.completed"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) is completed. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + - type: object + properties: + event: + type: string + enum: ["thread.message.incomplete"] + data: + $ref: "#/components/schemas/MessageObject" + required: + - event + - data + description: Occurs when a [message](/docs/api-reference/messages/object) ends before it is completed. + x-oaiMeta: + dataDescription: "`data` is a [message](/docs/api-reference/messages/object)" + + ErrorEvent: + type: object + properties: + event: + type: string + enum: ["error"] + data: + $ref: "#/components/schemas/Error" + required: + - event + - data + description: Occurs when an [error](/docs/guides/error-codes/api-errors) occurs. This can happen due to an internal server error or a timeout. + x-oaiMeta: + dataDescription: "`data` is an [error](/docs/guides/error-codes/api-errors)" + + DoneEvent: + type: object + properties: + event: + type: string + enum: ["done"] + data: + type: string + enum: ["[DONE]"] + required: + - event + - data + description: Occurs when a stream ends. + x-oaiMeta: + dataDescription: "`data` is `[DONE]`" + + Batch: + type: object + properties: + id: + type: string + object: + type: string + enum: [batch] + description: The object type, which is always `batch`. + endpoint: + type: string + description: The OpenAI API endpoint used by the batch. + + errors: + type: object + properties: + object: + type: string + description: The object type, which is always `list`. + data: + type: array + items: + type: object + properties: + code: + type: string + description: An error code identifying the error type. + message: + type: string + description: A human-readable message providing more details about the error. + param: + type: string + description: The name of the parameter that caused the error, if applicable. + nullable: true + line: + type: integer + description: The line number of the input file where the error occurred, if applicable. + nullable: true + input_file_id: + type: string + description: The ID of the input file for the batch. + completion_window: + type: string + description: The time frame within which the batch should be processed. + status: + type: string + description: The current status of the batch. + enum: + - validating + - failed + - in_progress + - finalizing + - completed + - expired + - cancelling + - cancelled + output_file_id: + type: string + description: The ID of the file containing the outputs of successfully executed requests. + error_file_id: + type: string + description: The ID of the file containing the outputs of requests with errors. + created_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was created. + in_progress_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started processing. + expires_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch will expire. + finalizing_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started finalizing. + completed_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was completed. + failed_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch failed. + expired_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch expired. + cancelling_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch started cancelling. + cancelled_at: + type: integer + description: The Unix timestamp (in seconds) for when the batch was cancelled. + request_counts: + type: object + properties: + total: + type: integer + description: Total number of requests in the batch. + completed: + type: integer + description: Number of requests that have been completed successfully. + failed: + type: integer + description: Number of requests that have failed. + required: + - total + - completed + - failed + description: The request counts for different statuses within the batch. + metadata: + description: *metadata_description + type: object + x-oaiTypeLabel: map + nullable: true + required: + - id + - object + - endpoint + - input_file_id + - completion_window + - status + - created_at + x-oaiMeta: + name: The batch object + example: *batch_object + + BatchRequestInput: + type: object + description: The per-line object of the batch input file + properties: + custom_id: + type: string + description: A developer-provided per-request id that will be used to match outputs to inputs. Must be unique for each request in a batch. + method: + type: string + enum: ["POST"] + description: The HTTP method to be used for the request. Currently only `POST` is supported. + url: + type: string + description: The OpenAI API relative URL to be used for the request. Currently `/v1/chat/completions`, `/v1/embeddings`, and `/v1/completions` are supported. + x-oaiMeta: + name: The request input object + example: | + {"custom_id": "request-1", "method": "POST", "url": "/v1/chat/completions", "body": {"model": "gpt-4o-mini", "messages": [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is 2+2?"}]}} + + BatchRequestOutput: + type: object + description: The per-line object of the batch output and error files + properties: + id: + type: string + custom_id: + type: string + description: A developer-provided per-request id that will be used to match outputs to inputs. + response: + type: object + nullable: true + properties: + status_code: + type: integer + description: The HTTP status code of the response + request_id: + type: string + description: An unique identifier for the OpenAI API request. Please include this request ID when contacting support. + body: + type: object + x-oaiTypeLabel: map + description: The JSON body of the response + error: + type: object + nullable: true + description: For requests that failed with a non-HTTP error, this will contain more information on the cause of the failure. + properties: + code: + type: string + description: A machine-readable error code. + message: + type: string + description: A human-readable error message. + x-oaiMeta: + name: The request output object + example: | + {"id": "batch_req_wnaDys", "custom_id": "request-2", "response": {"status_code": 200, "request_id": "req_c187b3", "body": {"id": "chatcmpl-9758Iw", "object": "chat.completion", "created": 1711475054, "model": "gpt-4o-mini", "choices": [{"index": 0, "message": {"role": "assistant", "content": "2 + 2 equals 4."}, "finish_reason": "stop"}], "usage": {"prompt_tokens": 24, "completion_tokens": 15, "total_tokens": 39}, "system_fingerprint": null}}, "error": null} + + ListBatchesResponse: + type: object + properties: + data: + type: array + items: + $ref: "#/components/schemas/Batch" + first_id: + type: string + example: "batch_abc123" + last_id: + type: string + example: "batch_abc456" + has_more: + type: boolean + object: + type: string + enum: [list] + required: + - object + - data + - has_more + + AuditLogActorServiceAccount: + type: object + description: The service account that performed the audit logged action. + properties: + id: + type: string + description: The service account id. + + AuditLogActorUser: + type: object + description: The user who performed the audit logged action. + properties: + id: + type: string + description: The user id. + email: + type: string + description: The user email. + + AuditLogActorApiKey: + type: object + description: The API Key used to perform the audit logged action. + properties: + id: + type: string + description: The tracking id of the API key. + type: + type: string + description: The type of API key. Can be either `user` or `service_account`. + enum: ["user", "service_account"] + user: + $ref: "#/components/schemas/AuditLogActorUser" + service_account: + $ref: "#/components/schemas/AuditLogActorServiceAccount" + + AuditLogActorSession: + type: object + description: The session in which the audit logged action was performed. + properties: + user: + $ref: "#/components/schemas/AuditLogActorUser" + ip_address: + type: string + description: The IP address from which the action was performed. + + AuditLogActor: + type: object + description: The actor who performed the audit logged action. + properties: + type: + type: string + description: The type of actor. Is either `session` or `api_key`. + enum: ["session", "api_key"] + session: + type: object + $ref: "#/components/schemas/AuditLogActorSession" + api_key: + type: object + $ref: "#/components/schemas/AuditLogActorApiKey" + + AuditLogEventType: + type: string + description: The event type. + x-oaiExpandable: true + enum: + - api_key.created + - api_key.updated + - api_key.deleted + - invite.sent + - invite.accepted + - invite.deleted + - login.succeeded + - login.failed + - logout.succeeded + - logout.failed + - organization.updated + - project.created + - project.updated + - project.archived + - service_account.created + - service_account.updated + - service_account.deleted + - user.added + - user.updated + - user.deleted + + AuditLog: + type: object + description: A log of a user action or configuration change within this organization. + properties: + id: + type: string + description: The ID of this log. + type: + $ref: "#/components/schemas/AuditLogEventType" + + effective_at: + type: integer + description: The Unix timestamp (in seconds) of the event. + project: + type: object + description: The project that the action was scoped to. Absent for actions not scoped to projects. + properties: + id: + type: string + description: The project ID. + name: + type: string + description: The project title. + actor: + $ref: "#/components/schemas/AuditLogActor" + api_key.created: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The tracking ID of the API key. + data: + type: object + description: The payload used to create the API key. + properties: + scopes: + type: array + items: + type: string + description: A list of scopes allowed for the API key, e.g. `["api.model.request"]` + api_key.updated: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The tracking ID of the API key. + changes_requested: + type: object + description: The payload used to update the API key. + properties: + scopes: + type: array + items: + type: string + description: A list of scopes allowed for the API key, e.g. `["api.model.request"]` + api_key.deleted: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The tracking ID of the API key. + invite.sent: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The ID of the invite. + data: + type: object + description: The payload used to create the invite. + properties: + email: + type: string + description: The email invited to the organization. + role: + type: string + description: The role the email was invited to be. Is either `owner` or `member`. + invite.accepted: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The ID of the invite. + invite.deleted: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The ID of the invite. + login.failed: + type: object + description: The details for events with this `type`. + properties: + error_code: + type: string + description: The error code of the failure. + error_message: + type: string + description: The error message of the failure. + logout.failed: + type: object + description: The details for events with this `type`. + properties: + error_code: + type: string + description: The error code of the failure. + error_message: + type: string + description: The error message of the failure. + organization.updated: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The organization ID. + changes_requested: + type: object + description: The payload used to update the organization settings. + properties: + title: + type: string + description: The organization title. + description: + type: string + description: The organization description. + name: + type: string + description: The organization name. + settings: + type: object + properties: + threads_ui_visibility: + type: string + description: Visibility of the threads page which shows messages created with the Assistants API and Playground. One of `ANY_ROLE`, `OWNERS`, or `NONE`. + usage_dashboard_visibility: + type: string + description: Visibility of the usage dashboard which shows activity and costs for your organization. One of `ANY_ROLE` or `OWNERS`. + project.created: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The project ID. + data: + type: object + description: The payload used to create the project. + properties: + name: + type: string + description: The project name. + title: + type: string + description: The title of the project as seen on the dashboard. + project.updated: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The project ID. + changes_requested: + type: object + description: The payload used to update the project. + properties: + title: + type: string + description: The title of the project as seen on the dashboard. + project.archived: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The project ID. + service_account.created: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The service account ID. + data: + type: object + description: The payload used to create the service account. + properties: + role: + type: string + description: The role of the service account. Is either `owner` or `member`. + service_account.updated: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The service account ID. + changes_requested: + type: object + description: The payload used to updated the service account. + properties: + role: + type: string + description: The role of the service account. Is either `owner` or `member`. + service_account.deleted: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The service account ID. + user.added: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The user ID. + data: + type: object + description: The payload used to add the user to the project. + properties: + role: + type: string + description: The role of the user. Is either `owner` or `member`. + user.updated: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The project ID. + changes_requested: + type: object + description: The payload used to update the user. + properties: + role: + type: string + description: The role of the user. Is either `owner` or `member`. + user.deleted: + type: object + description: The details for events with this `type`. + properties: + id: + type: string + description: The user ID. + required: + - id + - type + - effective_at + - actor + x-oaiMeta: + name: The audit log object + example: | + { + "id": "req_xxx_20240101", + "type": "api_key.created", + "effective_at": 1720804090, + "actor": { + "type": "session", + "session": { + "user": { + "id": "user-xxx", + "email": "user@example.com" + }, + "ip_address": "127.0.0.1", + "user_agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36" + } + }, + "api_key.created": { + "id": "key_xxxx", + "data": { + "scopes": ["resource.operation"] + } + } + } + + ListAuditLogsResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/AuditLog" + first_id: + type: string + example: "audit_log-defb456h8dks" + last_id: + type: string + example: "audit_log-hnbkd8s93s" + has_more: + type: boolean + + required: + - object + - data + - first_id + - last_id + - has_more + + Invite: + type: object + description: Represents an individual `invite` to the organization. + properties: + object: + type: string + enum: [organization.invite] + description: The object type, which is always `organization.invite` + id: + type: string + description: The identifier, which can be referenced in API endpoints + email: + type: string + description: The email address of the individual to whom the invite was sent + role: + type: string + enum: [owner, reader] + description: "`owner` or `reader`" + status: + type: string + enum: [accepted, expired, pending] + description: "`accepted`,`expired`, or `pending`" + invited_at: + type: integer + description: The Unix timestamp (in seconds) of when the invite was sent. + expires_at: + type: integer + description: The Unix timestamp (in seconds) of when the invite expires. + accepted_at: + type: integer + description: The Unix timestamp (in seconds) of when the invite was accepted. + + required: + - object + - id + - email + - role + - status + - invited_at + - expires_at + x-oaiMeta: + name: The invite object + example: | + { + "object": "organization.invite", + "id": "invite-abc", + "email": "user@example.com", + "role": "owner", + "status": "accepted", + "invited_at": 1711471533, + "expires_at": 1711471533, + "accepted_at": 1711471533 + } + + InviteListResponse: + type: object + properties: + object: + type: string + enum: [list] + description: The object type, which is always `list` + data: + type: array + items: + $ref: "#/components/schemas/Invite" + first_id: + type: string + description: The first `invite_id` in the retrieved `list` + last_id: + type: string + description: The last `invite_id` in the retrieved `list` + has_more: + type: boolean + description: The `has_more` property is used for pagination to indicate there are additional results. + required: + - object + - data + + InviteRequest: + type: object + properties: + email: + type: string + description: "Send an email to this address" + role: + type: string + enum: [reader, owner] + description: "`owner` or `reader`" + required: + - email + - role + + InviteDeleteResponse: + type: object + properties: + object: + type: string + enum: [organization.invite.deleted] + description: The object type, which is always `organization.invite.deleted` + id: + type: string + deleted: + type: boolean + required: + - object + - id + - deleted + + User: + type: object + description: Represents an individual `user` within an organization. + properties: + object: + type: string + enum: [organization.user] + description: The object type, which is always `organization.user` + id: + type: string + description: The identifier, which can be referenced in API endpoints + name: + type: string + description: The name of the user + email: + type: string + description: The email address of the user + role: + type: string + enum: [owner, reader] + description: "`owner` or `reader`" + added_at: + type: integer + description: The Unix timestamp (in seconds) of when the user was added. + required: + - object + - id + - name + - email + - role + - added_at + x-oaiMeta: + name: The user object + example: | + { + "object": "organization.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + UserListResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/User" + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + required: + - object + - data + - first_id + - last_id + - has_more + + UserRoleUpdateRequest: + type: object + properties: + role: + type: string + enum: [owner, reader] + description: "`owner` or `reader`" + required: + - role + + UserDeleteResponse: + type: object + properties: + object: + type: string + enum: [organization.user.deleted] + id: + type: string + deleted: + type: boolean + required: + - object + - id + - deleted + + Project: + type: object + description: Represents an individual project. + properties: + id: + type: string + description: The identifier, which can be referenced in API endpoints + object: + type: string + enum: [organization.project] + description: The object type, which is always `organization.project` + name: + type: string + description: The name of the project. This appears in reporting. + created_at: + type: integer + description: The Unix timestamp (in seconds) of when the project was created. + archived_at: + type: integer + nullable: true + description: The Unix timestamp (in seconds) of when the project was archived or `null`. + status: + type: string + enum: [active, archived] + description: "`active` or `archived`" + required: + - id + - object + - name + - created_at + - status + x-oaiMeta: + name: The project object + example: | + { + "id": "proj_abc", + "object": "organization.project", + "name": "Project example", + "created_at": 1711471533, + "archived_at": null, + "status": "active" + } + + ProjectListResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/Project" + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + required: + - object + - data + - first_id + - last_id + - has_more + + ProjectCreateRequest: + type: object + properties: + name: + type: string + description: The friendly name of the project, this name appears in reports. + required: + - name + + ProjectUpdateRequest: + type: object + properties: + name: + type: string + description: The updated name of the project, this name appears in reports. + required: + - name + + DefaultProjectErrorResponse: + type: object + properties: + code: + type: integer + message: + type: string + required: + - code + - message + + ProjectUser: + type: object + description: Represents an individual user in a project. + properties: + object: + type: string + enum: [organization.project.user] + description: The object type, which is always `organization.project.user` + id: + type: string + description: The identifier, which can be referenced in API endpoints + name: + type: string + description: The name of the user + email: + type: string + description: The email address of the user + role: + type: string + enum: [owner, member] + description: "`owner` or `member`" + added_at: + type: integer + description: The Unix timestamp (in seconds) of when the project was added. + + required: + - object + - id + - name + - email + - role + - added_at + x-oaiMeta: + name: The project user object + example: | + { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + + ProjectUserListResponse: + type: object + properties: + object: + type: string + data: + type: array + items: + $ref: "#/components/schemas/ProjectUser" + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + required: + - object + - data + - first_id + - last_id + - has_more + + ProjectUserCreateRequest: + type: object + properties: + user_id: + type: string + description: The ID of the user. + role: + type: string + enum: [owner, member] + description: "`owner` or `member`" + required: + - user_id + - role + + ProjectUserUpdateRequest: + type: object + properties: + role: + type: string + enum: [owner, member] + description: "`owner` or `member`" + required: + - role + + ProjectUserDeleteResponse: + type: object + properties: + object: + type: string + enum: [organization.project.user.deleted] + id: + type: string + deleted: + type: boolean + required: + - object + - id + - deleted + + ProjectServiceAccount: + type: object + description: Represents an individual service account in a project. + properties: + object: + type: string + enum: [organization.project.service_account] + description: The object type, which is always `organization.project.service_account` + id: + type: string + description: The identifier, which can be referenced in API endpoints + name: + type: string + description: The name of the service account + role: + type: string + enum: [owner, member] + description: "`owner` or `member`" + created_at: + type: integer + description: The Unix timestamp (in seconds) of when the service account was created + required: + - object + - id + - name + - role + - created_at + x-oaiMeta: + name: The project service account object + example: | + { + "object": "organization.project.service_account", + "id": "svc_acct_abc", + "name": "Service Account", + "role": "owner", + "created_at": 1711471533 + } + + ProjectServiceAccountListResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/ProjectServiceAccount" + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + required: + - object + - data + - first_id + - last_id + - has_more + + ProjectServiceAccountCreateRequest: + type: object + properties: + name: + type: string + description: The name of the service account being created. + required: + - name + + ProjectServiceAccountCreateResponse: + type: object + properties: + object: + type: string + enum: [organization.project.service_account] + id: + type: string + name: + type: string + role: + type: string + enum: [member] + description: Service accounts can only have one role of type `member` + created_at: + type: integer + api_key: + $ref: "#/components/schemas/ProjectServiceAccountApiKey" + required: + - object + - id + - name + - role + - created_at + - api_key + + ProjectServiceAccountApiKey: + type: object + properties: + object: + type: string + enum: [organization.project.service_account.api_key] + description: The object type, which is always `organization.project.service_account.api_key` + + value: + type: string + name: + type: string + created_at: + type: integer + id: + type: string + required: + - object + - value + - name + - created_at + - id + + ProjectServiceAccountDeleteResponse: + type: object + properties: + object: + type: string + enum: [organization.project.service_account.deleted] + id: + type: string + deleted: + type: boolean + required: + - object + - id + - deleted + + ProjectApiKey: + type: object + description: Represents an individual API key in a project. + properties: + object: + type: string + enum: [organization.project.api_key] + description: The object type, which is always `organization.project.api_key` + redacted_value: + type: string + description: The redacted value of the API key + name: + type: string + description: The name of the API key + created_at: + type: integer + description: The Unix timestamp (in seconds) of when the API key was created + id: + type: string + description: The identifier, which can be referenced in API endpoints + owner: + type: object + properties: + type: + type: string + enum: [user, service_account] + description: "`user` or `service_account`" + user: + $ref: "#/components/schemas/ProjectUser" + service_account: + $ref: "#/components/schemas/ProjectServiceAccount" + required: + - object + - redacted_value + - name + - created_at + - id + - owner + x-oaiMeta: + name: The project API key object + example: | + { + "object": "organization.project.api_key", + "redacted_value": "sk-abc...def", + "name": "My API Key", + "created_at": 1711471533, + "id": "key_abc", + "owner": { + "type": "user", + "user": { + "object": "organization.project.user", + "id": "user_abc", + "name": "First Last", + "email": "user@example.com", + "role": "owner", + "added_at": 1711471533 + } + } + } + + ProjectApiKeyListResponse: + type: object + properties: + object: + type: string + enum: [list] + data: + type: array + items: + $ref: "#/components/schemas/ProjectApiKey" + first_id: + type: string + last_id: + type: string + has_more: + type: boolean + required: + - object + - data + - first_id + - last_id + - has_more + + ProjectApiKeyDeleteResponse: + type: object + properties: + object: + type: string + enum: [organization.project.api_key.deleted] + id: + type: string + deleted: + type: boolean + required: + - object + - id + - deleted + +security: + - ApiKeyAuth: [] + +x-oaiMeta: + navigationGroups: + - id: endpoints + title: Endpoints + - id: assistants + title: Assistants + - id: administration + title: Administration + - id: legacy + title: Legacy + groups: + # > General Notes + # The `groups` section is used to generate the API reference pages and navigation, in the same + # order listed below. Additionally, each `group` can have a list of `sections`, each of which + # will become a navigation subroute and subsection under the group. Each section has: + # - `type`: Currently, either an `endpoint` or `object`, depending on how the section needs to + # be rendered + # - `key`: The reference key that can be used to lookup the section definition + # - `path`: The path (url) of the section, which is used to generate the navigation link. + # + # > The `object` sections maps to a schema component and the following fields are read for rendering + # - `x-oaiMeta.name`: The name of the object, which will become the section title + # - `x-oaiMeta.example`: The example object, which will be used to generate the example sample (always JSON) + # - `description`: The description of the object, which will be used to generate the section description + # + # > The `endpoint` section maps to an operation path and the following fields are read for rendering: + # - `x-oaiMeta.name`: The name of the endpoint, which will become the section title + # - `x-oaiMeta.examples`: The endpoint examples, which can be an object (meaning a single variation, most + # endpoints, or an array of objects, meaning multiple variations, e.g. the + # chat completion and completion endpoints, with streamed and non-streamed examples. + # - `x-oaiMeta.returns`: text describing what the endpoint returns. + # - `summary`: The summary of the endpoint, which will be used to generate the section description + - id: audio + title: Audio + description: | + Learn how to turn audio into text or text into audio. + + Related guide: [Speech to text](/docs/guides/speech-to-text) + navigationGroup: endpoints + sections: + - type: endpoint + key: createSpeech + path: createSpeech + - type: endpoint + key: createTranscription + path: createTranscription + - type: endpoint + key: createTranslation + path: createTranslation + - type: object + key: CreateTranscriptionResponseJson + path: json-object + - type: object + key: CreateTranscriptionResponseVerboseJson + path: verbose-json-object + - id: chat + title: Chat + description: | + Given a list of messages comprising a conversation, the model will return a response. + + Related guide: [Chat Completions](/docs/guides/text-generation) + navigationGroup: endpoints + sections: + - type: endpoint + key: createChatCompletion + path: create + - type: object + key: CreateChatCompletionResponse + path: object + - type: object + key: CreateChatCompletionStreamResponse + path: streaming + - id: embeddings + title: Embeddings + description: | + Get a vector representation of a given input that can be easily consumed by machine learning models and algorithms. + + Related guide: [Embeddings](/docs/guides/embeddings) + navigationGroup: endpoints + sections: + - type: endpoint + key: createEmbedding + path: create + - type: object + key: Embedding + path: object + - id: fine-tuning + title: Fine-tuning + description: | + Manage fine-tuning jobs to tailor a model to your specific training data. + + Related guide: [Fine-tune models](/docs/guides/fine-tuning) + navigationGroup: endpoints + sections: + - type: endpoint + key: createFineTuningJob + path: create + - type: endpoint + key: listPaginatedFineTuningJobs + path: list + - type: endpoint + key: listFineTuningEvents + path: list-events + - type: endpoint + key: listFineTuningJobCheckpoints + path: list-checkpoints + - type: endpoint + key: retrieveFineTuningJob + path: retrieve + - type: endpoint + key: cancelFineTuningJob + path: cancel + - type: object + key: FinetuneChatRequestInput + path: chat-input + - type: object + key: FinetuneCompletionRequestInput + path: completions-input + - type: object + key: FineTuningJob + path: object + - type: object + key: FineTuningJobEvent + path: event-object + - type: object + key: FineTuningJobCheckpoint + path: checkpoint-object + - id: batch + title: Batch + description: | + Create large batches of API requests for asynchronous processing. The Batch API returns completions within 24 hours for a 50% discount. + + Related guide: [Batch](/docs/guides/batch) + navigationGroup: endpoints + sections: + - type: endpoint + key: createBatch + path: create + - type: endpoint + key: retrieveBatch + path: retrieve + - type: endpoint + key: cancelBatch + path: cancel + - type: endpoint + key: listBatches + path: list + - type: object + key: Batch + path: object + - type: object + key: BatchRequestInput + path: request-input + - type: object + key: BatchRequestOutput + path: request-output + - id: files + title: Files + description: | + Files are used to upload documents that can be used with features like [Assistants](/docs/api-reference/assistants), [Fine-tuning](/docs/api-reference/fine-tuning), and [Batch API](/docs/guides/batch). + navigationGroup: endpoints + sections: + - type: endpoint + key: createFile + path: create + - type: endpoint + key: listFiles + path: list + - type: endpoint + key: retrieveFile + path: retrieve + - type: endpoint + key: deleteFile + path: delete + - type: endpoint + key: downloadFile + path: retrieve-contents + - type: object + key: OpenAIFile + path: object + - id: uploads + title: Uploads + description: | + Allows you to upload large files in multiple parts. + navigationGroup: endpoints + sections: + - type: endpoint + key: createUpload + path: create + - type: endpoint + key: addUploadPart + path: add-part + - type: endpoint + key: completeUpload + path: complete + - type: endpoint + key: cancelUpload + path: cancel + - type: object + key: Upload + path: object + - type: object + key: UploadPart + path: part-object + - id: images + title: Images + description: | + Given a prompt and/or an input image, the model will generate a new image. + + Related guide: [Image generation](/docs/guides/images) + navigationGroup: endpoints + sections: + - type: endpoint + key: createImage + path: create + - type: endpoint + key: createImageEdit + path: createEdit + - type: endpoint + key: createImageVariation + path: createVariation + - type: object + key: Image + path: object + - id: models + title: Models + description: | + List and describe the various models available in the API. You can refer to the [Models](/docs/models) documentation to understand what models are available and the differences between them. + navigationGroup: endpoints + sections: + - type: endpoint + key: listModels + path: list + - type: endpoint + key: retrieveModel + path: retrieve + - type: endpoint + key: deleteModel + path: delete + - type: object + key: Model + path: object + - id: moderations + title: Moderations + description: | + Given some input text, outputs if the model classifies it as potentially harmful across several categories. + + Related guide: [Moderations](/docs/guides/moderation) + navigationGroup: endpoints + sections: + - type: endpoint + key: createModeration + path: create + - type: object + key: CreateModerationResponse + path: object + + - id: assistants + title: Assistants + beta: true + description: | + Build assistants that can call models and use tools to perform tasks. + + [Get started with the Assistants API](/docs/assistants) + navigationGroup: assistants + sections: + - type: endpoint + key: createAssistant + path: createAssistant + - type: endpoint + key: listAssistants + path: listAssistants + - type: endpoint + key: getAssistant + path: getAssistant + - type: endpoint + key: modifyAssistant + path: modifyAssistant + - type: endpoint + key: deleteAssistant + path: deleteAssistant + - type: object + key: AssistantObject + path: object + - id: threads + title: Threads + beta: true + description: | + Create threads that assistants can interact with. + + Related guide: [Assistants](/docs/assistants/overview) + navigationGroup: assistants + sections: + - type: endpoint + key: createThread + path: createThread + - type: endpoint + key: getThread + path: getThread + - type: endpoint + key: modifyThread + path: modifyThread + - type: endpoint + key: deleteThread + path: deleteThread + - type: object + key: ThreadObject + path: object + - id: messages + title: Messages + beta: true + description: | + Create messages within threads + + Related guide: [Assistants](/docs/assistants/overview) + navigationGroup: assistants + sections: + - type: endpoint + key: createMessage + path: createMessage + - type: endpoint + key: listMessages + path: listMessages + - type: endpoint + key: getMessage + path: getMessage + - type: endpoint + key: modifyMessage + path: modifyMessage + - type: endpoint + key: deleteMessage + path: deleteMessage + - type: object + key: MessageObject + path: object + - id: runs + title: Runs + beta: true + description: | + Represents an execution run on a thread. + + Related guide: [Assistants](/docs/assistants/overview) + navigationGroup: assistants + sections: + - type: endpoint + key: createRun + path: createRun + - type: endpoint + key: createThreadAndRun + path: createThreadAndRun + - type: endpoint + key: listRuns + path: listRuns + - type: endpoint + key: getRun + path: getRun + - type: endpoint + key: modifyRun + path: modifyRun + - type: endpoint + key: submitToolOuputsToRun + path: submitToolOutputs + - type: endpoint + key: cancelRun + path: cancelRun + - type: object + key: RunObject + path: object + - id: run-steps + title: Run Steps + beta: true + description: | + Represents the steps (model and tool calls) taken during the run. + + Related guide: [Assistants](/docs/assistants/overview) + navigationGroup: assistants + sections: + - type: endpoint + key: listRunSteps + path: listRunSteps + - type: endpoint + key: getRunStep + path: getRunStep + - type: object + key: RunStepObject + path: step-object + - id: vector-stores + title: Vector Stores + beta: true + description: | + Vector stores are used to store files for use by the `file_search` tool. + + Related guide: [File Search](/docs/assistants/tools/file-search) + navigationGroup: assistants + sections: + - type: endpoint + key: createVectorStore + path: create + - type: endpoint + key: listVectorStores + path: list + - type: endpoint + key: getVectorStore + path: retrieve + - type: endpoint + key: modifyVectorStore + path: modify + - type: endpoint + key: deleteVectorStore + path: delete + - type: object + key: VectorStoreObject + path: object + - id: vector-stores-files + title: Vector Store Files + beta: true + description: | + Vector store files represent files inside a vector store. + + Related guide: [File Search](/docs/assistants/tools/file-search) + navigationGroup: assistants + sections: + - type: endpoint + key: createVectorStoreFile + path: createFile + - type: endpoint + key: listVectorStoreFiles + path: listFiles + - type: endpoint + key: getVectorStoreFile + path: getFile + - type: endpoint + key: deleteVectorStoreFile + path: deleteFile + - type: object + key: VectorStoreFileObject + path: file-object + - id: vector-stores-file-batches + title: Vector Store File Batches + beta: true + description: | + Vector store file batches represent operations to add multiple files to a vector store. + + Related guide: [File Search](/docs/assistants/tools/file-search) + navigationGroup: assistants + sections: + - type: endpoint + key: createVectorStoreFileBatch + path: createBatch + - type: endpoint + key: getVectorStoreFileBatch + path: getBatch + - type: endpoint + key: cancelVectorStoreFileBatch + path: cancelBatch + - type: endpoint + key: listFilesInVectorStoreBatch + path: listBatchFiles + - type: object + key: VectorStoreFileBatchObject + path: batch-object + - id: assistants-streaming + title: Streaming + beta: true + description: | + Stream the result of executing a Run or resuming a Run after submitting tool outputs. + + You can stream events from the [Create Thread and Run](/docs/api-reference/runs/createThreadAndRun), + [Create Run](/docs/api-reference/runs/createRun), and [Submit Tool Outputs](/docs/api-reference/runs/submitToolOutputs) + endpoints by passing `"stream": true`. The response will be a [Server-Sent events](https://html.spec.whatwg.org/multipage/server-sent-events.html#server-sent-events) stream. + + Our Node and Python SDKs provide helpful utilities to make streaming easy. Reference the + [Assistants API quickstart](/docs/assistants/overview) to learn more. + navigationGroup: assistants + sections: + - type: object + key: MessageDeltaObject + path: message-delta-object + - type: object + key: RunStepDeltaObject + path: run-step-delta-object + - type: object + key: AssistantStreamEvent + path: events + + - id: administration + title: Overview + description: | + Programmatically manage your organization. + + The Audit Logs endpoint provides a log of all actions taken in the + organization for security and monitoring purposes. + + To access these endpoints please generate an Admin API Key through the [API Platform Organization overview](/organization/admin-keys). Admin API keys cannot be used for non-administration endpoints. + + For best practices on setting up your organization, please refer to this [guide](/docs/guides/production-best-practices/setting-up-your-organization) + navigationGroup: administration + + - id: invite + title: Invites + description: Invite and manage invitations for an organization. Invited users are automatically added to the Default project. + navigationGroup: administration + sections: + - type: endpoint + key: list-invites + path: list + - type: endpoint + key: inviteUser + path: create + - type: endpoint + key: retrieve-invite + path: retrieve + - type: endpoint + key: delete-invite + path: delete + - type: object + key: Invite + path: object + + - id: users + title: Users + description: | + Manage users and their role in an organization. Users will be automatically added to the Default project. + navigationGroup: administration + sections: + - type: endpoint + key: list-users + path: list + - type: endpoint + key: modify-user + path: modify + - type: endpoint + key: retrieve-user + path: retrieve + - type: endpoint + key: delete-user + path: delete + - type: object + key: User + path: object + + - id: projects + title: Projects + description: | + Manage the projects within an orgnanization includes creation, updating, and archiving or projects. + The Default project cannot be modified or archived. + navigationGroup: administration + sections: + - type: endpoint + key: list-projects + path: list + - type: endpoint + key: create-project + path: create + - type: endpoint + key: retrieve-project + path: retrieve + - type: endpoint + key: modify-project + path: modify + - type: endpoint + key: archive-project + path: archive + - type: object + key: Project + path: object + + - id: project-users + title: Project Users + description: | + Manage users within a project, including adding, updating roles, and removing users. + Users cannot be removed from the Default project, unless they are being removed from the organization. + navigationGroup: administration + sections: + - type: endpoint + key: list-project-users + path: list + - type: endpoint + key: create-project-user + path: creeate + - type: endpoint + key: retrieve-project-user + path: retrieve + - type: endpoint + key: modify-project-user + path: modify + - type: endpoint + key: delete-project-user + path: delete + - type: object + key: ProjectUser + path: object + + - id: project-service-accounts + title: Project Service Accounts + description: | + Manage service accounts within a project. A service account is a bot user that is not associated with a user. + If a user leaves an organization, their keys and membership in projects will no longer work. Service accounts + do not have this limitation. However, service accounts can also be deleted from a project. + navigationGroup: administration + sections: + - type: endpoint + key: list-project-service-accounts + path: list + - type: endpoint + key: create-project-service-account + path: create + - type: endpoint + key: retrieve-project-service-account + path: retrieve + - type: endpoint + key: delete-project-service-account + path: delete + - type: object + key: ProjectServiceAccount + path: object + + - id: project-api-keys + title: Project API Keys + description: | + Manage API keys for a given project. Supports listing and deleting keys for users. + This API does not allow issuing keys for users, as users need to authorize themselves to generate keys. + navigationGroup: administration + sections: + - type: endpoint + key: list-project-api-keys + path: list + - type: endpoint + key: retrieve-project-api-key + path: retrieve + - type: endpoint + key: delete-project-api-key + path: delete + - type: object + key: ProjectApiKey + path: object + + - id: audit-logs + title: Audit Logs + description: | + Logs of user actions and configuration changes within this organization. + + To log events, you must activate logging in the [Organization Settings](/settings/organization/general). + Once activated, for security reasons, logging cannot be deactivated. + navigationGroup: administration + sections: + - type: endpoint + key: list-audit-logs + path: list + - type: object + key: AuditLog + path: object + + - id: completions + title: Completions + legacy: true + navigationGroup: legacy + description: | + Given a prompt, the model will return one or more predicted completions along with the probabilities of alternative tokens at each position. Most developer should use our [Chat Completions API](/docs/guides/text-generation/text-generation-models) to leverage our best and newest models. + sections: + - type: endpoint + key: createCompletion + path: create + - type: object + key: CreateCompletionResponse + path: object diff --git a/scripts/requirements.txt b/scripts/requirements.txt new file mode 100644 index 00000000..18993b10 --- /dev/null +++ b/scripts/requirements.txt @@ -0,0 +1,5 @@ +numpy==1.24.4 +typer==0.9.0 +lorem-text==2.1 +transformers==4.36.0 +chardet==5.2.0 \ No newline at end of file diff --git a/scripts/throughput_benchmarks.py b/scripts/throughput_benchmarks.py new file mode 100644 index 00000000..06888dd7 --- /dev/null +++ b/scripts/throughput_benchmarks.py @@ -0,0 +1,507 @@ +import csv +import json +import os +import queue +import random +import threading +import time +import traceback +from dataclasses import dataclass +from enum import Enum +from typing import List, Optional + +import numpy as np +import requests +import typer +from lorem_text import lorem +from transformers import AutoTokenizer + +AUTH_USER_ID = os.getenv("AUTH_USER_ID") +GATEWAY_URL = os.getenv("GATEWAY_URL") +app = typer.Typer(name="throughput-benchmarks", add_completion=False) + +MAX_CONTEXT_WINDOW = 100000 + + +@dataclass +class BenchmarkConfig: + def __init__(self, input_token_count, output_token_count_mean): + self.input_token_count = input_token_count + self.output_token_count_mean = output_token_count_mean + # Here we assume 3x standard deviation is enough to cover the range of output token counts. + # Also assume 3x stddev is rougly half of the mean. + self.output_token_count_std = output_token_count_mean / 6.0 + + def __repr__(self) -> str: + return f"BenchmarkConfig(input_token_count={self.input_token_count}, output_token_count_mean={self.output_token_count_mean}, output_token_count_std={self.output_token_count_std})" + + +HF_MODEL_MAPPING = { + "llama-2-7b": "meta-llama/Llama-2-7b-hf", + "llama-2-13b": "meta-llama/Llama-2-13b-hf", +} + + +class InferenceFramework(Enum): + TEXT_GENERATION_INFERENCE = "tgi" + VLLM = "vllm" + LIGHTLLM = "lightllm" + TENSORRT_LLM = "tensorrt-llm" + + @classmethod + def from_value(cls, value): + for member in cls: + if member.value == value: + return member + raise ValueError(f"No member with value {value} in {cls.__name__}") + + +def send_request(url, request, user=None): + start = time.time() + response = requests.post( + url, + json=request, + auth=(user, ""), + stream=True, + ) + first_line = True + inter_token_latencies = [] + last_token_time = None + payload_json: dict = {} + num_completion_tokens = 0 # We calculate this value manually since tensorrt llm doesn't give it + for byte_payload in response.iter_lines(): + # Skip line + if byte_payload == b"\n" or byte_payload == b"": + continue + + token_time = time.time() + if first_line: + time_to_first_token = token_time - start + last_token_time = token_time + first_line = False + else: + inter_token_latencies.append(token_time - last_token_time) + last_token_time = token_time + + payload = byte_payload.decode("utf-8") + + # Event data + if payload.startswith("data:"): + payload_data = payload.lstrip("data:").rstrip("/n") + payload_json = json.loads(payload_data) + num_completion_tokens += 1 + + return { + "payload": payload_json, + "time_to_first_token": time_to_first_token, + "total_time": time.time() - start, + "inter_token_latencies": inter_token_latencies, + "num_completion_tokens": num_completion_tokens, + } + + +def pull_and_send_request_from_queue( + model: str, + request_queue: queue.Queue, + result_queue: queue.Queue, + use_localhost: bool, + framework: InferenceFramework, + local_port: int = 5005, +): + while not request_queue.empty(): + request = request_queue.get() + if use_localhost: + if framework == InferenceFramework.VLLM: + response = send_request(f"http://localhost:{local_port}/stream", request) + response["num_completion_tokens"] = response["payload"][ + "count_output_tokens" + ] # vLLM gives us completion token count, use that. + elif framework == InferenceFramework.TENSORRT_LLM: + response = send_request( + f"http://localhost:{local_port}/v2/models/ensemble/generate_stream", request + ) + else: + raise NotImplementedError() + else: + response = send_request( + f"{GATEWAY_URL}/v1/llm/completions-stream?model_endpoint_name={model}", + request, + AUTH_USER_ID, + ) + response["num_completion_tokens"] = response["payload"]["output"][ + "num_completion_tokens" + ] + + result_queue.put(response) + + +def generate_request( + framework: InferenceFramework, prompt: str, output_token_count: int, localhost: bool +): + temperature = 0.0 + + if not localhost: + return {"prompt": prompt, "max_new_tokens": output_token_count, "temperature": temperature} + + if framework == InferenceFramework.TEXT_GENERATION_INFERENCE: + return { + "parameters": { + "do_sample": False, + "max_new_tokens": output_token_count, + "details": False, + }, + "inputs": prompt, + } + elif framework == InferenceFramework.VLLM: + return { + "prompt": prompt, + "max_tokens": output_token_count, + "temperature": temperature, + "stream": True, + } + elif framework == InferenceFramework.LIGHTLLM: + return { + "parameters": { + "do_sample": False, + "max_new_tokens": output_token_count, + }, + "inputs": prompt, + } + elif framework == InferenceFramework.TENSORRT_LLM: + return { + "max_tokens": output_token_count, + "text_input": prompt, + "bad_words": "", + "stop_words": "", + "parameters": { + "temperature": temperature, + "stream": True, + }, + } + else: + raise NotImplementedError() + + +def send_requests( + model: str, + prompt: str, + output_token_counts: List[int], + use_localhost: bool, + concurrency: int, + framework: InferenceFramework, + local_port: int = 5005, + prompts_list_override: Optional[List] = None, +): + thread_results: queue.Queue = queue.Queue() + requests_queue: queue.Queue = queue.Queue() + for i, output_token_count in enumerate(output_token_counts): + if prompts_list_override is not None: + new_prompt = prompts_list_override[i % len(prompts_list_override)] + else: + new_prompt = prompt + request = generate_request(framework, new_prompt, output_token_count, use_localhost) + requests_queue.put(request) + threads = [] + for i in range(concurrency): + thread = threading.Thread( + target=pull_and_send_request_from_queue, + args=( + model, + requests_queue, + thread_results, + use_localhost, + framework, + local_port, + ), + ) + thread.start() + threads.append(thread) + + for thread in threads: + thread.join() + + results = [] + while not thread_results.empty(): + results.append(thread_results.get()) + + return results + + +def generate_prompt(num, hf_model): + random.seed(1) + text = lorem.words(num // 2) # Roughly 2 tokens per lorem word + tokenizer = AutoTokenizer.from_pretrained(hf_model) + return tokenizer.decode(tokenizer.encode(text)[: num - 2]) + + +def generate_output_token_counts(mean, std, num, input_token_count): + output = np.random.normal(mean, std, num).astype(int).tolist() + + for i in range(len(output)): + output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count) + return output + + +def generate_output_token_counts_from_existing( + distribution: List[int], num: int, input_token_count: int +): + assert len(distribution) > 0, "Can't have a distribution with 0 tokens" + output = [] + # Sample without replacement so that we don't have as much variance + for _ in range(num // len(distribution)): + random.shuffle(distribution) + output.extend(distribution) + random.shuffle(distribution) + output.extend(distribution[: num % len(distribution)]) + assert len(output) == num + + for i in range(len(output)): + output[i] = min(output[i], MAX_CONTEXT_WINDOW - input_token_count) + return output + + +def read_data_from_json_file(fpath: str): + # Assumes the distribution is some json-formatted string that represents a list + try: + with open(fpath, "r") as fin: + return json.load(fin) + except FileNotFoundError: + print("File not found. Exiting.") + raise + + +def run_benchmark( + model: str, + framework: InferenceFramework, + hf_model: str, + config: BenchmarkConfig, + num_trials: int, + use_localhost: bool, + concurrency: int, + verbose: bool, + local_port: int, + response_token_count_distribution: Optional[List] = None, + prompts_list_override: Optional[List] = None, +): + prompt = generate_prompt(config.input_token_count, hf_model) + + prompt_num_tokens = config.input_token_count + + if response_token_count_distribution is not None: + output_token_counts = generate_output_token_counts_from_existing( + response_token_count_distribution, num_trials, config.input_token_count + ) + else: + output_token_counts = generate_output_token_counts( + config.output_token_count_mean, + config.output_token_count_std, + num_trials, + config.input_token_count, + ) + + start = time.time() + results = send_requests( + model, + prompt, + output_token_counts, + use_localhost, + concurrency, + framework, + local_port=local_port, + prompts_list_override=prompts_list_override, + ) + end = time.time() + elapsed = end - start + results = [result for result in results if result is not None] + + num_sampled_tokens = sum([result["num_completion_tokens"] for result in results]) + num_prompt_tokens = prompt_num_tokens * len(results) + n = len(results) + time_to_process_prompt = [] + time_per_completion = [] + time_to_first_token = [] + inter_token_latency = [] # one value per request, average inter-token latency in the request + total_request_time = [] + all_inter_token_latencies = [] # one value per token (except the first generated token) + for result in results: + avg_time_per_token = (result["total_time"] - result["time_to_first_token"]) / ( + max(1, result["num_completion_tokens"] - 1) + ) + time_to_first_token.append(result["time_to_first_token"]) + time_to_process_prompt.append(result["time_to_first_token"] - avg_time_per_token) + time_per_completion.append(result["total_time"] - time_to_process_prompt[-1]) + inter_token_latency.append(avg_time_per_token) + total_request_time.append(result["total_time"]) + all_inter_token_latencies.extend(result["inter_token_latencies"]) + + total_num_tokens = num_sampled_tokens + num_prompt_tokens + avg_prefill_time = sum(time_to_process_prompt) / n + avg_completion_time = sum(time_per_completion) / n + p50_request_time = np.percentile(total_request_time, 50) + p90_request_time = np.percentile(total_request_time, 90) + p95_request_time = np.percentile(total_request_time, 95) + p99_request_time = np.percentile(total_request_time, 99) + p50_inter_token_latency = np.percentile(all_inter_token_latencies, 50) + p90_inter_token_latency = np.percentile(all_inter_token_latencies, 90) + p95_inter_token_latency = np.percentile(all_inter_token_latencies, 95) + p99_inter_token_latency = np.percentile(all_inter_token_latencies, 99) + p999_inter_token_latency = np.percentile(all_inter_token_latencies, 99.9) + p50_time_to_first_token = np.percentile(time_to_first_token, 50) + p90_time_to_first_token = np.percentile(time_to_first_token, 90) + p95_time_to_first_token = np.percentile(time_to_first_token, 95) + p99_time_to_first_token = np.percentile(time_to_first_token, 99) + + statistics = { + "concurrency": concurrency, + "avg_prompt_throughput": num_prompt_tokens + / (elapsed * avg_prefill_time / (avg_prefill_time + avg_completion_time)), + "avg_time_to_first_token": sum(time_to_first_token) / n, + "p50_time_to_first_token": p50_time_to_first_token, + "p90_time_to_first_token": p90_time_to_first_token, + "p95_time_to_first_token": p95_time_to_first_token, + "p99_time_to_first_token": p99_time_to_first_token, + "avg_sampling_throughput": num_sampled_tokens + / (elapsed * avg_completion_time / (avg_prefill_time + avg_completion_time)), + "avg_total_throughput": total_num_tokens / elapsed, + "avg_per_session_sampling_throughput": num_sampled_tokens + / (elapsed * avg_completion_time / (avg_prefill_time + avg_completion_time)) + / concurrency, + "avg_request_throughput": n / elapsed, + "avg_inter_token_latency": sum(inter_token_latency) / n, + "p50_inter_token_latency": p50_inter_token_latency, + "p90_inter_token_latency": p90_inter_token_latency, + "p95_inter_token_latency": p95_inter_token_latency, + "p99_inter_token_latency": p99_inter_token_latency, + "p99.9_inter_token_latency": p999_inter_token_latency, + "num_prompt_tokens": prompt_num_tokens, + "avg_num_sampled_tokens": num_sampled_tokens / n, + "elapsed_time": elapsed, + "avg_prefill_time": avg_prefill_time, + "avg_completion_time": avg_completion_time, + "p50_request_time": p50_request_time, + "p90_request_time": p90_request_time, + "p95_request_time": p95_request_time, + "p99_request_time": p99_request_time, + "num_requests": num_trials, + "num_successful_requests": n, + "total_num_tokens": total_num_tokens, + "total_num_sampled_tokens": num_sampled_tokens, + } + if verbose: + print(f"Statistics: {statistics}") + + # Sleep for 1 seconds between each benchmark. + time.sleep(1) + + return statistics + + +@app.command() +def run_benchmarks( + model: str, + framework: str, + input_token_count: int, + output_token_count_mean: int, + num_trials: int = 50, + output_file: Optional[str] = None, + use_localhost: bool = False, + concurrency: int = 1, + verbose: bool = False, + hf_model: Optional[str] = None, + local_port: int = 5005, + response_token_count_distribution_file: Optional[str] = None, + prompts_list_override_file: Optional[str] = None, +): + """Run benchmarks.""" + all_statistics = [] + config = BenchmarkConfig(input_token_count, output_token_count_mean) + + response_token_count_distribution = None + if response_token_count_distribution_file is not None: + response_token_count_distribution = read_data_from_json_file( + response_token_count_distribution_file + ) + prompts_list_override = None + if prompts_list_override_file is not None: + prompts_list_override = read_data_from_json_file(prompts_list_override_file) + + try: + if verbose: + print(f"Running benchmark for config {config}") + if hf_model is None: + if model not in HF_MODEL_MAPPING: + raise ValueError( + f"--hf-model must be specified for model {model} since it's not in default mapping." + ) + hf_model = HF_MODEL_MAPPING[model] + statistics = run_benchmark( + model, + InferenceFramework.from_value(framework), + hf_model, + config, + num_trials, + use_localhost, + concurrency, + verbose, + local_port, + response_token_count_distribution, + prompts_list_override, + ) + all_statistics.append(statistics) + except Exception: + traceback.print_exc() + + if output_file is not None: + header = all_statistics[0].keys() + import os + + if not os.path.exists(output_file): + with open(output_file, "w") as csvfile: + print("creating the data in csv") + csv_writer = csv.DictWriter(csvfile, fieldnames=header) + csv_writer.writeheader() + csv_writer.writerows(all_statistics) + else: + with open(output_file, "a") as csvfile: + csv_writer = csv.DictWriter(csvfile, fieldnames=header) + csv_writer.writerows(all_statistics) + + +@app.command() +def run_benchmarks_concurrency_range( + model: str, + framework: str, + input_token_count: int, + output_token_count_mean: int, + num_trials_per_concurrency: int = 5, + output_file: Optional[str] = None, + use_localhost: bool = False, + concurrency_min: int = 1, + concurrency_max: int = 1, + concurrency_step: int = 1, + verbose: bool = False, + hf_model: Optional[str] = None, + local_port: int = 5005, + response_token_count_distribution_file: Optional[str] = None, + prompts_list_override_file: Optional[str] = None, +): + for concurrency in range(concurrency_min, concurrency_max + 1, concurrency_step): + run_benchmarks( + model, + framework, + input_token_count, + output_token_count_mean, + num_trials_per_concurrency * concurrency, + output_file, + use_localhost, + concurrency, + verbose, + hf_model, + local_port, + response_token_count_distribution_file, + prompts_list_override_file, + ) + + +if __name__ == "__main__": + app() diff --git a/server/Dockerfile.openapi b/server/Dockerfile.openapi deleted file mode 100644 index f0c892de..00000000 --- a/server/Dockerfile.openapi +++ /dev/null @@ -1,6 +0,0 @@ -FROM openapitools/openapi-generator-cli:v6.4.0 as openapi -RUN apt-get update && apt-get install -y npm && rm -rf /var/lib/apt/lists/* -RUN npm install @openapitools/openapi-generator-cli -g -RUN openapi-generator-cli version-manager set 6.4.0 -WORKDIR /local -ENTRYPOINT ["openapi-generator-cli"] diff --git a/server/Makefile b/server/Makefile deleted file mode 100644 index c673b1ad..00000000 --- a/server/Makefile +++ /dev/null @@ -1,43 +0,0 @@ -install: - pip install -r requirements.txt - pip install -r requirements_override.txt - pip install -e . - -install-test: - pip install -r requirements-test.txt - -install-dev: - pip install -r ../requirements-dev.txt - -install-docs: - pip install -r ../requirements-docs.txt - pip install -e ../clients/python/ - -requirements: install-dev - pip-compile --allow-unsafe --no-emit-index-url --no-emit-trusted-host --output-file=requirements.txt requirements.in - -install-all: install install-test install-dev install-docs - -test: - WORKSPACE=.. pytest - -autogen-templates: - pushd charts && \ - helm template llm-engine llm-engine -f llm-engine/values_circleci.yaml \ - -s templates/service_template_config_map.yaml \ - --set message='# THIS FILE IS AUTOGENERATED USING `just autogen-templates`. PLEASE EDIT THE GOTEMPLATE FILE IN THE HELM CHART!!!' \ - > ../llm_engine/infra/gateways/resources/templates/service_template_config_map_circleci.yaml \ - && popd - -build: - docker compose build llm-engine - -dev: - # TODO: add env variables to make this work. - docker compose up llm-engine-gateway-dev llm-engine-service-builder-dev - -build-docs: - mkdocs build - -dev-docs: - mkdocs serve diff --git a/server/docker-compose.yml b/server/docker-compose.yml deleted file mode 100644 index ca172d2a..00000000 --- a/server/docker-compose.yml +++ /dev/null @@ -1,155 +0,0 @@ -version: "3.8" - -services: - llm-engine: - build: - context: .. - dockerfile: server/Dockerfile - target: llm-engine - image: "${ECR_HOST:-local}/llm-engine:${GIT_SHA:-latest}" - llm-engine-gateway-dev: - build: - context: .. - dockerfile: server/Dockerfile - target: llm-engine - command: - - python - - -m - - llm_engine_server.entrypoints.start_fastapi_server - - --port=5001 - - --debug - - --num-workers=1 - environment: - - AWS_PROFILE - - SQS_PROFILE - - KUBECONFIG=/workspace/.kube/kubeconfig:/workspace/.kube/config - - SERVICE_IDENTIFIER - - AWS_CONFIG_FILE=/creds/.aws/config - - AWS_SHARED_CREDENTIALS_FILE=/creds/.aws/credentials - - CELERY_ELASTICACHE_ENABLED=true - - "GIT_TAG=${GIT_SHA}" - - DD_ENV=training - - "DEPLOY_SERVICE_CONFIG_PATH=/workspace/server/service_configs/service_config_${ENV}.yaml" - - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH=/workspace/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_${ENV}.yaml" - - "ML_INFRA_SERVICES_CONFIG_PATH=/workspace/server/llm_engine_server/core/configs/${ENV}.yaml" - - "DB_SECRET_NAME=${DB_SECRET_NAME:-}" - - "ML_INFRA_DATABASE_URL=${ML_INFRA_DATABASE_URL:-}" - - "USE_REDIS_LOCALHOST=${USE_REDIS_LOCALHOST:-}" - - "SKIP_AUTH=${SKIP_AUTH:-}" - - "CIRCLECI=${CIRCLECI:-}" - - "LOCAL=${LOCAL:-false}" - network_mode: host - ports: - - 5001:5001 - stdin_open: true - tty: true - volumes: - - "${HOME}/.kube:/workspace/.kube" - - "${HOME}/.minikube:/workspace/.minikube" - - "${HOME}/.minikube:/home/circleci/.minikube" - - "${HOME}/.aws-mountable:/creds/.aws" - - "../llm_engine:/workspace/llm_engine" - llm-engine-service-builder-dev: - build: - context: .. - dockerfile: server/Dockerfile - target: llm-engine - command: - - celery - - --app=llm_engine_server.service_builder - - worker - - --loglevel=INFO - - --concurrency=4 - - "--queues=${QUEUE}" - environment: - - AWS_PROFILE - - SQS_PROFILE - - ECR_READ_AWS_PROFILE - - DB_SECRET_AWS_PROFILE - - "S3_BUCKET=${S3_BUCKET:-scale-ml}" - - "DEPLOY_SERVICE_CONFIG_PATH=/workspace/server/service_configs/service_config_${ENV}.yaml" - - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH=/workspace/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_${ENV}.yaml" - - "ML_INFRA_SERVICES_CONFIG_PATH=/workspace/ml_infra_core/ml_infra_services/ml_infra_services/configs/${ENV}.yaml" - - "GIT_TAG=${GIT_SHA}" - - DD_ENV=training - - SERVICE_IDENTIFIER - - KUBECONFIG=/workspace/.kube/kubeconfig:/workspace/.kube/config - - AWS_CONFIG_FILE=/creds/.aws/config - - AWS_SHARED_CREDENTIALS_FILE=/creds/.aws/credentials - - CELERY_ELASTICACHE_ENABLED=true - - "KANIKO_TEMPLATE=${KANIKO_TEMPLATE:-kaniko_template.yaml}" - - "DB_SECRET_NAME=${DB_SECRET_NAME:-}" - - "ML_INFRA_DATABASE_URL=${ML_INFRA_DATABASE_URL:-}" - - "USE_REDIS_LOCALHOST=${USE_REDIS_LOCALHOST:-}" - - "SKIP_AUTH=${SKIP_AUTH:-}" - - "CIRCLECI=${CIRCLECI:-}" - - "LOCAL=${LOCAL:-false}" - network_mode: host - stdin_open: true - tty: true - volumes: - - "${HOME}/.kube:/workspace/.kube" - - "${HOME}/.minikube:/workspace/.minikube" - - "${HOME}/.minikube:/home/circleci/.minikube" - - "${HOME}/.aws-mountable:/creds/.aws" - - "../llm_engine:/workspace/llm_engine" - llm-engine-bash: - build: - context: .. - dockerfile: server/Dockerfile - target: llm-engine - command: - - /bin/bash - - -c - - "'${BASH_COMMAND:-/bin/bash}'" - environment: - - AWS_PROFILE - - SQS_PROFILE - - ECR_READ_AWS_PROFILE - - DB_SECRET_AWS_PROFILE - - "DEPLOY_SERVICE_CONFIG_PATH=/workspace/server/service_configs/service_config_${ENV}.yaml" - - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH=/workspace/server/llm_engine_server/infra/gateways/resources/templates/service_template_config_map_${ENV}.yaml" - - "ML_INFRA_SERVICES_CONFIG_PATH=/workspace/ml_infra_core/ml_infra_services/ml_infra_services/configs/${ENV}.yaml" - - "GIT_TAG=${GIT_SHA}" - - DD_ENV=training - - SERVICE_IDENTIFIER - - KUBECONFIG=/workspace/.kube/kubeconfig:/workspace/.kube/config - - AWS_CONFIG_FILE=/creds/.aws/config - - AWS_SHARED_CREDENTIALS_FILE=/creds/.aws/credentials - - CELERY_ELASTICACHE_ENABLED=true - - "DB_SECRET_NAME=${DB_SECRET_NAME:-}" - - "ML_INFRA_DATABASE_URL=${ML_INFRA_DATABASE_URL:-}" - - "USE_REDIS_LOCALHOST=${USE_REDIS_LOCALHOST:-}" - - "CIRCLECI=${CIRCLECI:-}" - - "LOCAL=${LOCAL:-false}" - network_mode: host - ports: - - 5002:5000 - volumes: - - "${HOME}/.kube:/workspace/.kube" - - "${HOME}/.minikube:/workspace/.minikube" - - "${HOME}/.minikube:/home/circleci/.minikube" - - "${HOME}/.aws-mountable:/creds/.aws" - - "../llm_engine:/workspace/llm_engine" - db: - image: "cimg/postgres:12.8-postgis" - ports: - - 5432:5432 - environment: - - POSTGRES_USER=ml_infra_test - - POSTGRES_DB=ml_infra_test - - POSTGRES_PASSWORD=ml_infra_test - redis: - image: redis - ports: - - 6379:6379 - openapi-generator-cli: - image: "${ECR_HOST:-local}/ml_infra_core/openapi:${GIT_SHA:-latest}" - build: - context: .. - dockerfile: ml_infra_core/Dockerfile.openapi - target: base - volumes: - - "../llm_engine/clients:/local" - command: - - generate diff --git a/server/llm_engine_server/api/app.py b/server/llm_engine_server/api/app.py deleted file mode 100644 index 72703909..00000000 --- a/server/llm_engine_server/api/app.py +++ /dev/null @@ -1,36 +0,0 @@ -from fastapi import FastAPI, Response -from llm_engine_server.api.batch_jobs_v1 import batch_job_router_v1 -from llm_engine_server.api.dependencies import get_or_create_aioredis_pool -from llm_engine_server.api.docker_image_batch_job_bundles_v1 import ( - docker_image_batch_job_bundle_router_v1, -) -from llm_engine_server.api.llms_v1 import llm_router_v1 -from llm_engine_server.api.model_bundles_v1 import model_bundle_router_v1 -from llm_engine_server.api.model_bundles_v2 import model_bundle_router_v2 -from llm_engine_server.api.model_endpoints_docs_v1 import model_endpoints_docs_router_v1 -from llm_engine_server.api.model_endpoints_v1 import model_endpoint_router_v1 -from llm_engine_server.api.tasks_v1 import inference_task_router_v1 - -app = FastAPI(title="llm_engine", version="1.0.0", redoc_url="/api") - -app.include_router(batch_job_router_v1) -app.include_router(inference_task_router_v1) -app.include_router(model_bundle_router_v1) -app.include_router(model_bundle_router_v2) -app.include_router(model_endpoint_router_v1) -app.include_router(model_endpoints_docs_router_v1) -app.include_router(docker_image_batch_job_bundle_router_v1) -app.include_router(llm_router_v1) - - -@app.on_event("startup") -def load_redis(): - get_or_create_aioredis_pool() - - -@app.get("/healthcheck") -@app.get("/healthz") -@app.get("/readyz") -def healthcheck() -> Response: - """Returns 200 if the app is healthy.""" - return Response(status_code=200) diff --git a/server/llm_engine_server/api/dependencies.py b/server/llm_engine_server/api/dependencies.py deleted file mode 100644 index f00211a8..00000000 --- a/server/llm_engine_server/api/dependencies.py +++ /dev/null @@ -1,280 +0,0 @@ -import asyncio -import os -from dataclasses import dataclass -from typing import Callable, Iterator, Optional - -import aioredis -from fastapi import Depends, HTTPException, status -from fastapi.security import HTTPBasic, HTTPBasicCredentials -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.model_endpoints import BrokerType -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.core.auth.authentication_repository import AuthenticationRepository, User -from llm_engine_server.core.auth.fake_authentication_repository import FakeAuthenticationRepository -from llm_engine_server.db.base import SessionAsync, SessionReadOnlyAsync -from llm_engine_server.domain.gateways import ( - DockerImageBatchJobGateway, - ModelPrimitiveGateway, - TaskQueueGateway, -) -from llm_engine_server.domain.repositories import ( - DockerImageBatchJobBundleRepository, - DockerRepository, - ModelBundleRepository, -) -from llm_engine_server.domain.services import ( - BatchJobService, - LLMModelEndpointService, - ModelEndpointService, -) -from llm_engine_server.infra.gateways import ( - CeleryTaskQueueGateway, - FakeMonitoringMetricsGateway, - LiveAsyncModelEndpointInferenceGateway, - LiveBatchJobOrchestrationGateway, - LiveBatchJobProgressGateway, - LiveDockerImageBatchJobGateway, - LiveModelEndpointInfraGateway, - LiveModelEndpointsSchemaGateway, - LiveStreamingModelEndpointInferenceGateway, - LiveSyncModelEndpointInferenceGateway, - ModelEndpointInfraGateway, - S3FilesystemGateway, -) -from llm_engine_server.infra.gateways.fake_model_primitive_gateway import FakeModelPrimitiveGateway -from llm_engine_server.infra.gateways.resources.endpoint_resource_gateway import ( - EndpointResourceGateway, -) -from llm_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( - FakeSQSEndpointResourceDelegate, -) -from llm_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( - LiveEndpointResourceGateway, -) -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, -) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, -) -from llm_engine_server.infra.repositories import ( - DbBatchJobRecordRepository, - DbDockerImageBatchJobBundleRepository, - DbModelBundleRepository, - DbModelEndpointRecordRepository, - ECRDockerRepository, - RedisModelEndpointCacheRepository, - S3FileLLMFineTuningJobRepository, -) -from llm_engine_server.infra.services import ( - DockerImageBatchJobLLMFineTuningService, - LiveBatchJobService, - LiveModelEndpointService, -) -from llm_engine_server.infra.services.live_llm_model_endpoint_service import ( - LiveLLMModelEndpointService, -) -from sqlalchemy.ext.asyncio import AsyncSession, async_scoped_session - -AUTH = HTTPBasic(auto_error=False) - - -@dataclass -class ExternalInterfaces: - """ - Internal object used for aggregating various Gateway and Repository objects for dependency - injection. - """ - - docker_repository: DockerRepository - docker_image_batch_job_bundle_repository: DockerImageBatchJobBundleRepository - model_bundle_repository: ModelBundleRepository - model_endpoint_service: ModelEndpointService - batch_job_service: BatchJobService - llm_model_endpoint_service: LLMModelEndpointService - llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService - - resource_gateway: EndpointResourceGateway - endpoint_creation_task_queue_gateway: TaskQueueGateway - inference_task_queue_gateway: TaskQueueGateway - model_endpoint_infra_gateway: ModelEndpointInfraGateway - docker_image_batch_job_gateway: DockerImageBatchJobGateway - model_primitive_gateway: ModelPrimitiveGateway - - -def _get_external_interfaces( - read_only: bool, session: Callable[[], AsyncSession] -) -> ExternalInterfaces: - """ - Dependency that returns a ExternalInterfaces object. This allows repositories to share - sessions for the database and redis. - """ - redis_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS) - redis_24h_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.REDIS_24H) - sqs_task_queue_gateway = CeleryTaskQueueGateway(broker_type=BrokerType.SQS) - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() - model_endpoint_record_repo = DbModelEndpointRecordRepository( - monitoring_metrics_gateway=monitoring_metrics_gateway, - session=session, - read_only=read_only, - ) - - sqs_delegate: SQSEndpointResourceDelegate - if CIRCLECI: - sqs_delegate = FakeSQSEndpointResourceDelegate() - else: - sqs_delegate = LiveSQSEndpointResourceDelegate( - sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) - ) - - inference_task_queue_gateway = ( - sqs_task_queue_gateway if not CIRCLECI else redis_24h_task_queue_gateway - ) - resource_gateway = LiveEndpointResourceGateway(sqs_delegate=sqs_delegate) - redis_client = aioredis.Redis(connection_pool=get_or_create_aioredis_pool()) - model_endpoint_cache_repo = RedisModelEndpointCacheRepository( - redis_client=redis_client, - ) - model_endpoint_infra_gateway = LiveModelEndpointInfraGateway( - resource_gateway=resource_gateway, - task_queue_gateway=redis_task_queue_gateway, - ) - async_model_endpoint_inference_gateway = LiveAsyncModelEndpointInferenceGateway( - task_queue_gateway=inference_task_queue_gateway - ) - # In CircleCI, we cannot use asyncio because aiohttp cannot connect to the sync endpoints. - sync_model_endpoint_inference_gateway = LiveSyncModelEndpointInferenceGateway( - use_asyncio=(not CIRCLECI), - ) - streaming_model_endpoint_inference_gateway = LiveStreamingModelEndpointInferenceGateway( - use_asyncio=(not CIRCLECI), - ) - filesystem_gateway = S3FilesystemGateway() - model_endpoints_schema_gateway = LiveModelEndpointsSchemaGateway( - filesystem_gateway=filesystem_gateway - ) - model_endpoint_service = LiveModelEndpointService( - model_endpoint_record_repository=model_endpoint_record_repo, - model_endpoint_infra_gateway=model_endpoint_infra_gateway, - model_endpoint_cache_repository=model_endpoint_cache_repo, - async_model_endpoint_inference_gateway=async_model_endpoint_inference_gateway, - streaming_model_endpoint_inference_gateway=streaming_model_endpoint_inference_gateway, - sync_model_endpoint_inference_gateway=sync_model_endpoint_inference_gateway, - model_endpoints_schema_gateway=model_endpoints_schema_gateway, - ) - llm_model_endpoint_service = LiveLLMModelEndpointService( - model_endpoint_record_repository=model_endpoint_record_repo, - model_endpoint_service=model_endpoint_service, - ) - model_bundle_repository = DbModelBundleRepository(session=session, read_only=read_only) - docker_image_batch_job_bundle_repository = DbDockerImageBatchJobBundleRepository( - session=session, read_only=read_only - ) - batch_job_record_repository = DbBatchJobRecordRepository(session=session, read_only=read_only) - batch_job_orchestration_gateway = LiveBatchJobOrchestrationGateway() - batch_job_progress_gateway = LiveBatchJobProgressGateway(filesystem_gateway=filesystem_gateway) - batch_job_service = LiveBatchJobService( - batch_job_record_repository=batch_job_record_repository, - model_endpoint_service=model_endpoint_service, - batch_job_orchestration_gateway=batch_job_orchestration_gateway, - batch_job_progress_gateway=batch_job_progress_gateway, - ) - - model_primitive_gateway: ModelPrimitiveGateway - model_primitive_gateway = FakeModelPrimitiveGateway() - - docker_image_batch_job_gateway = LiveDockerImageBatchJobGateway() - - llm_fine_tuning_job_repository = S3FileLLMFineTuningJobRepository( - file_path=os.getenv( - "S3_FILE_LLM_FINE_TUNING_JOB_REPOSITORY", - hmi_config.s3_file_llm_fine_tuning_job_repository, - ), - ) - llm_fine_tuning_service = DockerImageBatchJobLLMFineTuningService( - docker_image_batch_job_gateway=docker_image_batch_job_gateway, - docker_image_batch_job_bundle_repo=docker_image_batch_job_bundle_repository, - llm_fine_tuning_job_repository=llm_fine_tuning_job_repository, - ) - - external_interfaces = ExternalInterfaces( - docker_repository=ECRDockerRepository(), - model_bundle_repository=model_bundle_repository, - model_endpoint_service=model_endpoint_service, - llm_model_endpoint_service=llm_model_endpoint_service, - batch_job_service=batch_job_service, - resource_gateway=resource_gateway, - endpoint_creation_task_queue_gateway=redis_task_queue_gateway, - inference_task_queue_gateway=sqs_task_queue_gateway, - model_endpoint_infra_gateway=model_endpoint_infra_gateway, - model_primitive_gateway=model_primitive_gateway, - docker_image_batch_job_bundle_repository=docker_image_batch_job_bundle_repository, - docker_image_batch_job_gateway=docker_image_batch_job_gateway, - llm_fine_tuning_service=llm_fine_tuning_service, - ) - return external_interfaces - - -async def get_external_interfaces(): - try: - session = async_scoped_session(SessionAsync, scopefunc=asyncio.current_task) - yield _get_external_interfaces(read_only=False, session=session) - finally: - pass - - -async def get_external_interfaces_read_only(): - try: - session = async_scoped_session(SessionReadOnlyAsync, scopefunc=asyncio.current_task) - yield _get_external_interfaces(read_only=True, session=session) - finally: - pass - - -def get_auth_repository() -> Iterator[AuthenticationRepository]: - """ - Dependency for an AuthenticationRepository. This implementation returns a Scale-specific repository. - """ - try: - yield FakeAuthenticationRepository() - finally: - pass - - -async def verify_authentication( - credentials: HTTPBasicCredentials = Depends(AUTH), - auth_repo: AuthenticationRepository = Depends(get_auth_repository), -) -> User: - """ - Verifies the authentication headers and returns a (user_id, team_id) auth tuple. Otherwise, - raises a 401. - """ - user_id = credentials.username if credentials is not None else None - if user_id is None: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="No user id was passed in", - headers={"WWW-Authenticate": "Basic"}, - ) - - auth = await auth_repo.get_auth_from_user_id_async(user_id=user_id) - - if not auth: - raise HTTPException( - status_code=status.HTTP_401_UNAUTHORIZED, - detail="Could not authenticate user", - headers={"WWW-Authenticate": "Basic"}, - ) - - return auth - - -_pool: Optional[aioredis.BlockingConnectionPool] = None - - -def get_or_create_aioredis_pool() -> aioredis.ConnectionPool: - global _pool - - if _pool is None: - _pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) - return _pool diff --git a/server/llm_engine_server/api/llms_v1.py b/server/llm_engine_server/api/llms_v1.py deleted file mode 100644 index 1d54ca5e..00000000 --- a/server/llm_engine_server/api/llms_v1.py +++ /dev/null @@ -1,325 +0,0 @@ -"""LLM Model Endpoint routes for the hosted model inference service. -""" -from typing import Optional - -from fastapi import APIRouter, Depends, HTTPException, Query -from llm_engine_server.api.dependencies import ( - ExternalInterfaces, - get_external_interfaces, - get_external_interfaces_read_only, - verify_authentication, -) -from llm_engine_server.common.datadog_utils import add_trace_resource_name -from llm_engine_server.common.dtos.llms import ( - CancelFineTuneJobResponse, - CompletionStreamV1Request, - CompletionStreamV1Response, - CompletionSyncV1Request, - CompletionSyncV1Response, - CreateFineTuneJobRequest, - CreateFineTuneJobResponse, - CreateLLMModelEndpointV1Request, - CreateLLMModelEndpointV1Response, - GetFineTuneJobResponse, - GetLLMModelEndpointV1Response, - ListFineTuneJobResponse, - ListLLMModelEndpointsV1Response, -) -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.common.dtos.tasks import TaskStatus -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( - ObjectAlreadyExistsException, - ObjectHasInvalidValueException, - ObjectNotApprovedException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import ( - EndpointLabelsException, - EndpointResourceInvalidRequestException, - EndpointUnsupportedInferenceTypeException, - InvalidRequestException, - LLMFineTuningMethodNotImplementedException, - UpstreamServiceError, -) -from llm_engine_server.domain.use_cases.llm_fine_tuning_use_cases import ( - CancelFineTuneJobV1UseCase, - CreateFineTuneJobV1UseCase, - GetFineTuneJobV1UseCase, - ListFineTuneJobV1UseCase, -) -from llm_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( - CompletionStreamV1UseCase, - CompletionSyncV1UseCase, - CreateLLMModelEndpointV1UseCase, - GetLLMModelEndpointByNameV1UseCase, - ListLLMModelEndpointsV1UseCase, -) -from llm_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase -from sse_starlette.sse import EventSourceResponse - -llm_router_v1 = APIRouter(prefix="/v1/llm") -logger = make_logger(filename_wo_ext(__name__)) - - -@llm_router_v1.post("/model-endpoints", response_model=CreateLLMModelEndpointV1Response) -async def create_model_endpoint( - request: CreateLLMModelEndpointV1Request, - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), -) -> CreateLLMModelEndpointV1Response: - """ - Creates an LLM endpoint for the current user. - """ - add_trace_resource_name("llm_model_endpoints_post") - logger.info(f"POST /llm/model-endpoints with {request} for {auth}") - try: - create_model_bundle_use_case = CreateModelBundleV2UseCase( - model_bundle_repository=external_interfaces.model_bundle_repository, - docker_repository=external_interfaces.docker_repository, - model_primitive_gateway=external_interfaces.model_primitive_gateway, - ) - use_case = CreateLLMModelEndpointV1UseCase( - create_model_bundle_use_case=create_model_bundle_use_case, - model_bundle_repository=external_interfaces.model_bundle_repository, - model_endpoint_service=external_interfaces.model_endpoint_service, - ) - return await use_case.execute(user=auth, request=request) - except ObjectAlreadyExistsException as exc: - raise HTTPException( - status_code=400, - detail="The specified model endpoint already exists.", - ) from exc - except EndpointLabelsException as exc: - raise HTTPException( - status_code=400, - detail=str(exc), - ) from exc - except ObjectHasInvalidValueException as exc: - raise HTTPException(status_code=400, detail=str(exc)) - except EndpointResourceInvalidRequestException as exc: - raise HTTPException( - status_code=400, - detail=str(exc), - ) from exc - except ObjectNotApprovedException as exc: - raise HTTPException( - status_code=403, - detail="The specified model bundle was not approved yet.", - ) from exc - except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: - raise HTTPException( - status_code=404, - detail="The specified model bundle could not be found.", - ) from exc - - -@llm_router_v1.get("/model-endpoints", response_model=ListLLMModelEndpointsV1Response) -async def list_model_endpoints( - name: Optional[str] = Query(default=None), - order_by: Optional[ModelEndpointOrderBy] = Query(default=None), - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> ListLLMModelEndpointsV1Response: - """ - Lists the LLM model endpoints owned by the current owner, plus all public_inference LLMs. - """ - add_trace_resource_name("llm_model_endpoints_get") - logger.info(f"GET /llm/model-endpoints?name={name}&order_by={order_by} for {auth}") - use_case = ListLLMModelEndpointsV1UseCase( - llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, - ) - return await use_case.execute(user=auth, name=name, order_by=order_by) - - -@llm_router_v1.get( - "/model-endpoints/{model_endpoint_name}", - response_model=GetLLMModelEndpointV1Response, -) -async def get_model_endpoint( - model_endpoint_name: str, - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> GetLLMModelEndpointV1Response: - """ - Describe the LLM Model endpoint with given name. - """ - add_trace_resource_name("llm_model_endpoints_name_get") - logger.info(f"GET /llm/model-endpoints/{model_endpoint_name} for {auth}") - try: - use_case = GetLLMModelEndpointByNameV1UseCase( - llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service - ) - return await use_case.execute(user=auth, model_endpoint_name=model_endpoint_name) - except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: - raise HTTPException( - status_code=404, - detail=f"Model Endpoint {model_endpoint_name} was not found.", - ) from exc - - -@llm_router_v1.post("/completions-sync", response_model=CompletionSyncV1Response) -async def create_completion_sync_task( - model_endpoint_name: str, - request: CompletionSyncV1Request, - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> CompletionSyncV1Response: - """ - Runs a sync prompt completion on an LLM. - """ - add_trace_resource_name("llm_completion_sync_post") - logger.info( - f"POST /completion_sync with {request} to endpoint {model_endpoint_name} for {auth}" - ) - try: - use_case = CompletionSyncV1UseCase( - model_endpoint_service=external_interfaces.model_endpoint_service, - llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, - ) - return await use_case.execute( - user=auth, model_endpoint_name=model_endpoint_name, request=request - ) - except UpstreamServiceError as exc: - return CompletionSyncV1Response( - status=TaskStatus.FAILURE, outputs=[], traceback=exc.content.decode() - ) - except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: - raise HTTPException( - status_code=404, - detail="The specified endpoint could not be found.", - ) from exc - except ObjectHasInvalidValueException as exc: - raise HTTPException(status_code=400, detail=str(exc)) - except EndpointUnsupportedInferenceTypeException as exc: - raise HTTPException( - status_code=400, - detail=f"Unsupported inference type: {str(exc)}", - ) from exc - - -@llm_router_v1.post("/completions-stream", response_model=CompletionStreamV1Response) -async def create_completion_stream_task( - model_endpoint_name: str, - request: CompletionStreamV1Request, - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> EventSourceResponse: - """ - Runs a stream prompt completion on an LLM. - """ - add_trace_resource_name("llm_completion_stream_post") - logger.info( - f"POST /completion_stream with {request} to endpoint {model_endpoint_name} for {auth}" - ) - try: - use_case = CompletionStreamV1UseCase( - model_endpoint_service=external_interfaces.model_endpoint_service, - llm_model_endpoint_service=external_interfaces.llm_model_endpoint_service, - ) - response = use_case.execute( - user=auth, model_endpoint_name=model_endpoint_name, request=request - ) - - async def event_generator(): - async for message in response: - yield {"data": message.json()} - - return EventSourceResponse(event_generator()) - except UpstreamServiceError as exc: - return EventSourceResponse( - iter( - ( - CompletionStreamV1Response( - status=TaskStatus.FAILURE, traceback=exc.content.decode() - ).json(), - ) - ) - ) - except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: - raise HTTPException( - status_code=404, - detail="The specified endpoint could not be found.", - ) from exc - except ObjectHasInvalidValueException as exc: - raise HTTPException(status_code=400, detail=str(exc)) - except EndpointUnsupportedInferenceTypeException as exc: - raise HTTPException( - status_code=400, - detail=f"Unsupported inference type: {str(exc)}", - ) from exc - - -@llm_router_v1.post("/fine-tunes", response_model=CreateFineTuneJobResponse) -async def create_fine_tune_job( - request: CreateFineTuneJobRequest, - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), -) -> CreateFineTuneJobResponse: - add_trace_resource_name("fine_tunes_create") - logger.info(f"POST /fine-tunes with {request} for {auth}") - try: - use_case = CreateFineTuneJobV1UseCase( - llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, - ) - return await use_case.execute(user=auth, request=request) - except (LLMFineTuningMethodNotImplementedException, InvalidRequestException) as exc: - raise HTTPException( - status_code=400, - detail=str(exc), - ) from exc - - -@llm_router_v1.get("/fine-tunes/{fine_tune_id}", response_model=GetFineTuneJobResponse) -async def get_fine_tune_job( - fine_tune_id: str, - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> GetFineTuneJobResponse: - add_trace_resource_name("fine_tunes_get") - logger.info(f"GET /fine-tunes/{fine_tune_id} for {auth}") - try: - use_case = GetFineTuneJobV1UseCase( - llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, - ) - return await use_case.execute(user=auth, fine_tune_id=fine_tune_id) - except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: - raise HTTPException( - status_code=404, - detail="The specified fine-tune job could not be found.", - ) from exc - - -@llm_router_v1.get("/fine-tunes", response_model=ListFineTuneJobResponse) -async def list_fine_tune_jobs( - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces_read_only), -) -> ListFineTuneJobResponse: - add_trace_resource_name("fine_tunes_list") - logger.info(f"GET /fine-tunes for {auth}") - use_case = ListFineTuneJobV1UseCase( - llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, - ) - return await use_case.execute(user=auth) - - -@llm_router_v1.put("/fine-tunes/{fine_tune_id}/cancel", response_model=CancelFineTuneJobResponse) -async def cancel_fine_tune_job( - fine_tune_id: str, - auth: User = Depends(verify_authentication), - external_interfaces: ExternalInterfaces = Depends(get_external_interfaces), -) -> CancelFineTuneJobResponse: - add_trace_resource_name("fine_tunes_cancel") - logger.info(f"PUT /fine-tunes/{fine_tune_id}/cancel for {auth}") - try: - use_case = CancelFineTuneJobV1UseCase( - llm_fine_tuning_service=external_interfaces.llm_fine_tuning_service, - ) - return await use_case.execute(user=auth, fine_tune_id=fine_tune_id) - except (ObjectNotFoundException, ObjectNotAuthorizedException) as exc: - raise HTTPException( - status_code=404, - detail="The specified fine-tune job could not be found.", - ) from exc diff --git a/server/llm_engine_server/api/worker.py b/server/llm_engine_server/api/worker.py deleted file mode 100644 index 95b02c59..00000000 --- a/server/llm_engine_server/api/worker.py +++ /dev/null @@ -1,17 +0,0 @@ -from uvicorn.workers import UvicornWorker - -# The target concurrency is around 50, so we set the limit to 32 with 4 workers -# for a total concurrency of 128 to allow for some headroom. -CONCURRENCY_LIMIT = 32 - - -class LLMEngineWorker(UvicornWorker): - """Overrides the configuration of the Uvicorn Worker.""" - - # uvloop and httptools are both faster than their alternatives, but they are not compatible - # with Windows or PyPy. - CONFIG_KWARGS = { - "loop": "uvloop", - "http": "httptools", - "limit_concurrency": CONCURRENCY_LIMIT, - } diff --git a/server/llm_engine_server/common/config.py b/server/llm_engine_server/common/config.py deleted file mode 100644 index 11d5fede..00000000 --- a/server/llm_engine_server/common/config.py +++ /dev/null @@ -1,52 +0,0 @@ -# Keep in line with service_config_{*}.yaml -# This file loads sensitive data that shouldn't make it to inference docker images -# Do not include this file in our inference/endpoint code -import os -from dataclasses import dataclass -from pathlib import Path -from typing import Sequence - -import yaml -from llm_engine_server.core.loggers import filename_wo_ext, make_logger - -logger = make_logger(filename_wo_ext(__file__)) - -__all__: Sequence[str] = ( - "DEFAULT_SERVICE_CONFIG_PATH", - "SERVICE_CONFIG_PATH", - "HostedModelInferenceServiceConfig", - "hmi_config", -) - -DEFAULT_SERVICE_CONFIG_PATH = str( - ( - Path(__file__).absolute().parent.parent.parent / "service_configs" / "service_config.yaml" - ).absolute() -) - -SERVICE_CONFIG_PATH = os.environ.get("DEPLOY_SERVICE_CONFIG_PATH", DEFAULT_SERVICE_CONFIG_PATH) - - -@dataclass -class HostedModelInferenceServiceConfig: - endpoint_namespace: str - cache_redis_url: str - sqs_profile: str - sqs_queue_policy_template: str - sqs_queue_tag_template: str - s3_file_llm_fine_tuning_job_repository: str - datadog_trace_enabled: str - - @classmethod - def from_yaml(cls, yaml_path): - with open(yaml_path, "r") as f: - raw_data = yaml.safe_load(f) - return HostedModelInferenceServiceConfig(**raw_data) - - -def read_default_config(): - logger.info(f"Using config file path: `{SERVICE_CONFIG_PATH}`") - return HostedModelInferenceServiceConfig.from_yaml(SERVICE_CONFIG_PATH) - - -hmi_config = read_default_config() diff --git a/server/llm_engine_server/common/constants.py b/server/llm_engine_server/common/constants.py deleted file mode 100644 index 87048c20..00000000 --- a/server/llm_engine_server/common/constants.py +++ /dev/null @@ -1,13 +0,0 @@ -from pathlib import Path - -CALLBACK_POST_INFERENCE_HOOK: str = "callback" -READYZ_FPATH: str = "/tmp/readyz" -DEFAULT_CELERY_TASK_NAME: str = "llm_engine_server.inference.async_inference.tasks.predict" -LIRA_CELERY_TASK_NAME: str = "llm_engine_server.inference.celery_service.exec_func" # TODO: FIXME - -PROJECT_ROOT: Path = Path(__file__).parents[2].absolute() -HOSTED_MODEL_INFERENCE_ROOT: Path = PROJECT_ROOT / "llm_engine" - -FEATURE_FLAG_USE_MULTI_CONTAINER_ARCHITECTURE_FOR_ARTIFACTLIKE_BUNDLE: str = ( - "USE_MULTI_CONTAINER_ARCHITECTURE_FOR_ARTIFACTLIKE_BUNDLE" -) diff --git a/server/llm_engine_server/common/datadog_utils.py b/server/llm_engine_server/common/datadog_utils.py deleted file mode 100644 index a7ee6a4e..00000000 --- a/server/llm_engine_server/common/datadog_utils.py +++ /dev/null @@ -1,10 +0,0 @@ -from ddtrace import tracer - - -def add_trace_resource_name(tag: str): - """Adds a custom tag to a given dd trace corresponding to the route - (e.g. get_model_bundles for GET /model-bundles, etc.) so that we can filter in Datadog easier - """ - current_span = tracer.current_span() - if current_span: - current_span.set_tag("llm_engine_server.resource_name", tag) diff --git a/server/llm_engine_server/common/dtos/llms.py b/server/llm_engine_server/common/dtos/llms.py deleted file mode 100644 index 17d6f13d..00000000 --- a/server/llm_engine_server/common/dtos/llms.py +++ /dev/null @@ -1,171 +0,0 @@ -""" -DTOs for LLM APIs. -""" - -from typing import Any, Dict, List, Optional - -from llm_engine_server.common.dtos.model_endpoints import ( - CpuSpecificationType, - GetModelEndpointV1Response, - GpuType, - ModelEndpointType, - StorageSpecificationType, -) -from llm_engine_server.domain.entities import ( - BatchJobStatus, - CallbackAuth, - LLMInferenceFramework, - LLMSource, - Quantization, -) -from pydantic import BaseModel, Field, HttpUrl - -from .tasks import TaskStatus - - -class CreateLLMModelEndpointV1Request(BaseModel): - name: str - - # LLM specific fields - model_name: str - source: LLMSource = LLMSource.HUGGING_FACE - inference_framework: LLMInferenceFramework = LLMInferenceFramework.DEEPSPEED - inference_framework_image_tag: str - num_shards: int = 1 - """ - Number of shards to distribute the model onto GPUs. Only affects behavior for text-generation-inference models - """ - - quantize: Optional[Quantization] = None - """ - Whether to quantize the model. Only affect behavior for text-generation-inference models - """ - - checkpoint_path: Optional[str] = None - """ - Path to the checkpoint to load the model from. Only affects behavior for text-generation-inference models - """ - - # General endpoint fields - metadata: Dict[str, Any] # TODO: JSON type - post_inference_hooks: Optional[List[str]] - endpoint_type: ModelEndpointType = ModelEndpointType.SYNC - cpus: CpuSpecificationType - gpus: int - memory: StorageSpecificationType - gpu_type: GpuType - storage: Optional[StorageSpecificationType] - optimize_costs: Optional[bool] - min_workers: int - max_workers: int - per_worker: int - labels: Dict[str, str] - prewarm: Optional[bool] - high_priority: Optional[bool] - default_callback_url: Optional[HttpUrl] - default_callback_auth: Optional[CallbackAuth] - public_inference: Optional[bool] = True # LLM endpoints are public by default. - - -class CreateLLMModelEndpointV1Response(BaseModel): - endpoint_creation_task_id: str - - -class GetLLMModelEndpointV1Response(BaseModel): - id: str - """ - The autogenerated ID of the LLMEngine endpoint. - """ - - name: str - model_name: str - source: LLMSource - inference_framework: LLMInferenceFramework - inference_framework_image_tag: str - num_shards: int - quantize: Optional[Quantization] = None - spec: GetModelEndpointV1Response - - -class ListLLMModelEndpointsV1Response(BaseModel): - model_endpoints: List[GetLLMModelEndpointV1Response] - - -# Delete and update use the default LLMEngine endpoint APIs. - - -class CompletionSyncV1Request(BaseModel): - """ - Request object for a synchronous prompt completion task. - """ - - prompts: List[str] - max_new_tokens: int - temperature: float = Field(gt=0, le=100) - - -class CompletionOutput(BaseModel): - text: str - num_completion_tokens: int - - -class CompletionSyncV1Response(BaseModel): - """ - Response object for a synchronous prompt completion task. - """ - - status: TaskStatus - outputs: List[CompletionOutput] - traceback: Optional[str] = None - - -class CompletionStreamV1Request(BaseModel): - """ - Request object for a stream prompt completion task. - """ - - prompt: str - max_new_tokens: int - temperature: float = Field(gt=0, le=100) - - -class CompletionStreamOutput(BaseModel): - text: str - finished: bool - num_completion_tokens: Optional[int] = None - - -class CompletionStreamV1Response(BaseModel): - """ - Response object for a stream prompt completion task. - """ - - status: TaskStatus - output: Optional[CompletionStreamOutput] = None - traceback: Optional[str] = None - - -class CreateFineTuneJobRequest(BaseModel): - training_file: str - validation_file: str - model_name: str - base_model: str # TODO enum - fine_tuning_method: str # TODO enum - hyperparameters: Dict[str, str] # TODO validated somewhere else - - -class CreateFineTuneJobResponse(BaseModel): - fine_tune_id: str - - -class GetFineTuneJobResponse(BaseModel): - fine_tune_id: str - status: BatchJobStatus - - -class ListFineTuneJobResponse(BaseModel): - jobs: List[GetFineTuneJobResponse] - - -class CancelFineTuneJobResponse(BaseModel): - success: bool diff --git a/server/llm_engine_server/common/dtos/resource_manager.py b/server/llm_engine_server/common/dtos/resource_manager.py deleted file mode 100644 index cb6bea9a..00000000 --- a/server/llm_engine_server/common/dtos/resource_manager.py +++ /dev/null @@ -1,7 +0,0 @@ -from llm_engine_server.common.dtos.endpoint_builder import BuildEndpointRequest -from pydantic import BaseModel - - -class CreateOrUpdateResourcesRequest(BaseModel): - build_endpoint_request: BuildEndpointRequest - image: str diff --git a/server/llm_engine_server/common/env_vars.py b/server/llm_engine_server/common/env_vars.py deleted file mode 100644 index fac2325c..00000000 --- a/server/llm_engine_server/common/env_vars.py +++ /dev/null @@ -1,67 +0,0 @@ -""" -A place for defining, setting, and referencing all environment variables used in LLMEngine. -""" -import os -from typing import Optional, Sequence - -from llm_engine_server.common.constants import PROJECT_ROOT -from llm_engine_server.core.loggers import logger_name, make_logger - -__all__: Sequence[str] = ( - "CIRCLECI", - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH", - "LLM_ENGINE_SERVICE_TEMPLATE_FOLDER", - "LOCAL", - "WORKSPACE", - "get_boolean_env_var", -) - -logger = make_logger(logger_name()) - - -def get_boolean_env_var(name: str) -> bool: - """For all env vars that are either on or off. - - An env var is ON iff: - - it is defined - - its value is the literal string 'true' - - If it is present but not set to 'true', it is considered to be OFF. - """ - value = os.environ.get(name) - if value is None: - return False - value = value.strip().lower() - return "true" == value - - -CIRCLECI: bool = get_boolean_env_var("CIRCLECI") - -LOCAL: bool = get_boolean_env_var("LOCAL") -"""Indicates that LLMEngine is running in a local development environment. Also used for local testing. -""" - -WORKSPACE: str = os.environ.get("WORKSPACE", "~/models") -"""The working directory where llm_engine is installed. -""" - -LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH: str = os.environ.get( - "LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH", - os.path.join( - PROJECT_ROOT, - "llm_engine_server/infra/gateways/resources/templates", - "service_template_config_map_circleci.yaml", - ), -) -"""The path to the config map containing the LLMEngine service template. -""" - -LLM_ENGINE_SERVICE_TEMPLATE_FOLDER: Optional[str] = os.environ.get( - "LLM_ENGINE_SERVICE_TEMPLATE_FOLDER" -) -"""The path to the folder containing the LLMEngine service template. If set, this overrides -LLM_ENGINE_SERVICE_TEMPLATE_CONFIG_MAP_PATH. -""" - -if LOCAL: - logger.warning("LOCAL development & testing mode is ON") diff --git a/server/llm_engine_server/common/io.py b/server/llm_engine_server/common/io.py deleted file mode 100644 index 8ee049db..00000000 --- a/server/llm_engine_server/common/io.py +++ /dev/null @@ -1,14 +0,0 @@ -"""LLMEngine Input/Output utils.""" -import os - -import boto3 -import smart_open - - -def open_wrapper(uri: str, mode: str = "rt", **kwargs): - # This follows the 5.1.0 smart_open API - profile_name = kwargs.get("aws_profile", os.getenv("AWS_PROFILE")) - session = boto3.Session(profile_name=profile_name) - client = session.client("s3") - transport_params = {"client": client} - return smart_open.open(uri, mode, transport_params=transport_params) diff --git a/server/llm_engine_server/common/pydantic_types/endpoint_predict_payload.py b/server/llm_engine_server/common/pydantic_types/endpoint_predict_payload.py deleted file mode 100644 index 218099a1..00000000 --- a/server/llm_engine_server/common/pydantic_types/endpoint_predict_payload.py +++ /dev/null @@ -1,10 +0,0 @@ -from typing import Any, Optional - -from pydantic import BaseModel - - -class EndpointPredictPayload(BaseModel): - url: Optional[str] = None - args: Optional[Any] = None - cloudpickle: Optional[str] = None - return_pickled: bool diff --git a/server/llm_engine_server/common/settings.py b/server/llm_engine_server/common/settings.py deleted file mode 100644 index 4d0af5e4..00000000 --- a/server/llm_engine_server/common/settings.py +++ /dev/null @@ -1,66 +0,0 @@ -# This file contains standard settings for ML serve. -# - -import hashlib -from typing import List - -from llm_engine_server.core.config import ml_infra_config - -DEPLOYMENT_PREFIX = "llm-engine" -SERVICE_BUILDER_QUEUE_PREFIX = "llm-engine" -SERVICE_BUILDER_QUEUE_SUFFIX = "service-builder" - -RESTRICTED_ENDPOINT_LABELS = set( - [ - "user_id", - "endpoint_name", - ] -) - -REQUIRED_ENDPOINT_LABELS = set( - [ - "team", - "product", - ] -) - -PRETRAINED_ENDPOINTS_CREATED_BY = ["nucleus-model-zoo", "bloom", "llm", "pretrained"] - - -def generate_deployment_name(user_id, endpoint_name): - return "-".join(_generate_deployment_name_parts(user_id, endpoint_name)) - - -def _generate_queue_name(user_id, endpoint_name): - return ".".join(_generate_deployment_name_parts(user_id, endpoint_name)) - - -def generate_destination(user_id: str, endpoint_name: str, endpoint_type: str) -> str: - if endpoint_type == "async": - return _generate_queue_name(user_id, endpoint_name) - elif endpoint_type in {"sync", "streaming"}: - return generate_deployment_name(user_id, endpoint_name) - else: - raise ValueError(f"Invalid endpoint_type: {endpoint_type}") - - -def _generate_deployment_name_parts(user_id: str, endpoint_name: str) -> List[str]: - user_endpoint_hash = hashlib.md5((user_id + endpoint_name).encode("utf-8")).hexdigest() - return [ - DEPLOYMENT_PREFIX, - user_id[:24], - endpoint_name[:8], - user_endpoint_hash[:8], - ] - - -def get_service_builder_queue(service_identifier=None): - return ( - f"{SERVICE_BUILDER_QUEUE_PREFIX}-{service_identifier}.{SERVICE_BUILDER_QUEUE_SUFFIX}" - if service_identifier - else f"{SERVICE_BUILDER_QUEUE_PREFIX}.{SERVICE_BUILDER_QUEUE_SUFFIX}" - ) - - -def get_service_builder_logs_location(user_id: str, endpoint_name: str): - return f"s3://{ml_infra_config().s3_bucket}/service_builder_logs/{user_id}_{endpoint_name}" diff --git a/server/llm_engine_server/core/auth/fake_authentication_repository.py b/server/llm_engine_server/core/auth/fake_authentication_repository.py deleted file mode 100644 index 5da02827..00000000 --- a/server/llm_engine_server/core/auth/fake_authentication_repository.py +++ /dev/null @@ -1,24 +0,0 @@ -from typing import Dict, Optional - -from llm_engine_server.core.auth.authentication_repository import AuthenticationRepository, User - - -class FakeAuthenticationRepository(AuthenticationRepository): - def __init__(self, user_team_override: Optional[Dict[str, str]] = None): - if user_team_override is None: - user_team_override = {} - self.user_team_override = user_team_override - - def get_auth_from_user_id(self, user_id: str) -> Optional[User]: - team_id = self.user_team_override.get(user_id, user_id) - return User(user_id=user_id, team_id=team_id, is_privileged_user=True) - - async def get_auth_from_user_id_async(self, user_id: str) -> Optional[User]: - team_id = self.user_team_override.get(user_id, user_id) - return User(user_id=user_id, team_id=team_id, is_privileged_user=True) - - def get_auth_from_api_key(self, api_key: str) -> Optional[User]: - return User(user_id=api_key, team_id=api_key, is_privileged_user=True) - - async def get_auth_from_api_key_async(self, api_key: str) -> Optional[User]: - return User(user_id=api_key, team_id=api_key, is_privileged_user=True) diff --git a/server/llm_engine_server/core/aws/sfn_client.py b/server/llm_engine_server/core/aws/sfn_client.py deleted file mode 100644 index 30b6593f..00000000 --- a/server/llm_engine_server/core/aws/sfn_client.py +++ /dev/null @@ -1,21 +0,0 @@ -"""This module provides a client for the AWS Step Functions service.""" -import os -from typing import Optional - -from botocore.client import BaseClient -from llm_engine_server.core.aws.roles import session -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import logger_name, make_logger - -logger = make_logger(logger_name()) - - -def sync_sfn_client(**kwargs) -> Optional[BaseClient]: - is_testing_mode = os.environ.get("TESTING_DISABLE_SFN", "").lower() == "true" - if is_testing_mode: - logger.error( - "Not creating step function client as we are in testing mode." - "THIS SHOULD NOT HAPPEN IN PRODUCTION!" - ) - return None - return session(ml_infra_config().profile_ml_worker).client("stepfunctions", **kwargs) diff --git a/server/llm_engine_server/core/celery/__init__.py b/server/llm_engine_server/core/celery/__init__.py deleted file mode 100644 index cb4eb189..00000000 --- a/server/llm_engine_server/core/celery/__init__.py +++ /dev/null @@ -1,8 +0,0 @@ -from typing import Sequence - -from .app import TaskVisibility, celery_app - -__all__: Sequence[str] = ( - "celery_app", - "TaskVisibility", -) diff --git a/server/llm_engine_server/core/config.py b/server/llm_engine_server/core/config.py deleted file mode 100644 index b3a02a85..00000000 --- a/server/llm_engine_server/core/config.py +++ /dev/null @@ -1,90 +0,0 @@ -"""AWS configuration for ml-infra-services. - -The configuration file is loaded from the ML_INFRA_SERVICES_CONFIG_PATH environment variable. -If this is not set, the default configuration file is used from -llm_engine_server.core/configs/default.yaml. -""" -import os -from contextlib import contextmanager -from copy import deepcopy -from dataclasses import dataclass -from pathlib import Path -from typing import Optional, Sequence - -import yaml -from llm_engine_server.core.loggers import filename_wo_ext, make_logger - -logger = make_logger(filename_wo_ext(__file__)) - -__all__: Sequence[str] = ( - "DEFAULT_CONFIG_PATH", - "CONFIG_PATH", - "config_context", - "get_config_path_for_env_name", - "ml_infra_config", - "use_config_context", -) - -DEFAULT_CONFIG_PATH = Path(__file__).parent / "configs" / "circleci.yaml" -CONFIG_PATH: str = os.getenv("ML_INFRA_SERVICES_CONFIG_PATH", str(DEFAULT_CONFIG_PATH)) - - -@dataclass -class MLInfraServicesConfig: - env: str - k8s_cluster_name: str - dns_host_domain: str - default_region: str - ml_account_id: str - docker_repo_prefix: str - redis_host: str - s3_bucket: str - profile_ml_worker: str = "default" - profile_ml_inference_worker: str = "default" - - @classmethod - def from_yaml(cls, yaml_path) -> "MLInfraServicesConfig": - with open(yaml_path, "r") as f: - raw_data = yaml.safe_load(f) - return MLInfraServicesConfig(**raw_data) - - -def read_default_config(): - logger.info(f"Using config file path: `{CONFIG_PATH}`") - return MLInfraServicesConfig.from_yaml(CONFIG_PATH) - - -_ml_infra_config: Optional[MLInfraServicesConfig] = None - - -def ml_infra_config() -> MLInfraServicesConfig: - global _ml_infra_config - if _ml_infra_config is None: - _ml_infra_config = read_default_config() - return _ml_infra_config - - -@contextmanager -def config_context(config_path: str): - """Context manager that temporarily changes the config file path.""" - global _ml_infra_config - current_config = deepcopy(_ml_infra_config) - try: - _ml_infra_config = MLInfraServicesConfig.from_yaml(config_path) - yield - finally: - _ml_infra_config = current_config - - -def use_config_context(config_path: str): - """Use the config file at the given path.""" - global _ml_infra_config - _ml_infra_config = MLInfraServicesConfig.from_yaml(config_path) - - -def get_config_path_for_env_name(env_name: str) -> Path: - path = DEFAULT_CONFIG_PATH.parent / f"{env_name}.yaml" - if not path.exists(): - print(path) - raise ValueError(f"Config file does not exist for env: {env_name}") - return path diff --git a/server/llm_engine_server/core/domain_exceptions.py b/server/llm_engine_server/core/domain_exceptions.py deleted file mode 100644 index 62068614..00000000 --- a/server/llm_engine_server/core/domain_exceptions.py +++ /dev/null @@ -1,59 +0,0 @@ -from dataclasses import dataclass - - -class DomainException(Exception): - """ - Base class for exceptions thrown for domain (business logic) errors. - """ - - -class ObjectAlreadyExistsException(DomainException): - """ - Thrown when the user tries to create a model with a name that already exists. - """ - - -class ObjectNotFoundException(DomainException): - """ - Thrown when a required object is not found, e.g. when creating a version for a nonexistent model - """ - - -class ObjectNotAuthorizedException(DomainException): - """ - Thrown when a user tries to access an object they don't own. - """ - - -class ObjectHasInvalidValueException(DomainException, ValueError): - """ - Thrown when a user tries to create an object with an invalid value. - """ - - -class ObjectNotApprovedException(DomainException): - """ - Thrown when a required object is not approved, e.g. for a Bundle in review. - """ - - -@dataclass -class DockerImageNotFoundException(DomainException): - """ - Thrown when a user tries to specify a custom Docker image that cannot be found. - """ - - repository: str - tag: str - - -class DockerBuildFailedException(DomainException): - """ - Thrown if the server failed to build a docker image. - """ - - -class ReadOnlyDatabaseException(DomainException): - """ - Thrown if the server attempted to write to a read-only database. - """ diff --git a/server/llm_engine_server/core/kubernetes.py b/server/llm_engine_server/core/kubernetes.py deleted file mode 100644 index 59589d7c..00000000 --- a/server/llm_engine_server/core/kubernetes.py +++ /dev/null @@ -1,81 +0,0 @@ -import logging -from enum import Enum -from pathlib import Path -from string import Template -from typing import Iterator, Union - -import yaml -from kubeconfig import KubeConfig - -from .loggers import make_logger - -logger = make_logger(__file__, log_level=logging.DEBUG) -_config = KubeConfig() - -_K8S_CONFIGS = {} - - -class LifecycleSelector(str, Enum): - NORMAL = "normal" - SPOT = "spot" - - -def k8s_config() -> str: - """Returns the name of the current kubernetes context""" - return _config.view()["current-context"].strip() - - -def check_k8s_config(env_name: str) -> bool: - """ - Checks whether the current k8s context (i.e. which cluster you're on) - is the one given by the config. - """ - assert env_name in _K8S_CONFIGS - cur_config = k8s_config() - return cur_config.strip() == _K8S_CONFIGS[env_name].strip() - - -def substitute_yaml(fp: Union[str, Path], **kwargs) -> dict: - """Read a file from disk, substitute options, return yaml - - The yaml file must have the variables to substitute written as $VAR or ${VAR}. See documentation - for string.Template for more details. - - Args: - fp: path to a yaml file - **kwargs: all the keyword arguments needed to substitute flags in the yaml file - - Returns: - Returns a dict of parsed yaml - - Raises: - FileNotFoundError: If no file exists at the path - KeyError: If a keyword argument is specified for a key that doesn't exist, or a key is - specified and no corresponding argument is passed in. - """ - with open(fp, "r") as template_f: - config = yaml.safe_load(Template(template_f.read()).substitute(**kwargs)) - return config - - -def substitute_yamls(fp: Union[str, Path], **kwargs) -> Iterator: - """Read a file from disk, substitute options, return yaml - - The yaml file must have the variables to substitute written as $VAR or ${VAR}. See documentation - for string.Template for more details. - - Args: - fp: path to a yaml file - **kwargs: all the keyword arguments needed to substitute flags in the yaml file - - Returns: - Returns a list of dicts of parsed yaml - - Raises: - FileNotFoundError: If no file exists at the path - KeyError: If a keyword argument is specified for a key that doesn't exist, or a key is - specified and no corresponding argument is passed in. - """ - with open(fp, "r") as template_f: - config = yaml.safe_load_all(Template(template_f.read()).substitute(**kwargs)) - return config diff --git a/server/llm_engine_server/core/testing_utilities.py b/server/llm_engine_server/core/testing_utilities.py deleted file mode 100644 index 5b80dd9d..00000000 --- a/server/llm_engine_server/core/testing_utilities.py +++ /dev/null @@ -1,140 +0,0 @@ -"""Utility functions for Python programs. Should not be used by other modules in this package.""" -import os -import platform -from functools import lru_cache -from tempfile import NamedTemporaryFile -from typing import Callable, Iterable, Optional, Sequence, Tuple, TypeVar - -from llm_engine_server.core.aws.storage_client import sync_storage_client -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.utils.url import parse_attachment_url - -In = TypeVar("In") -"""Type variable representing the function under test's input type. -""" - -Out = TypeVar("Out") -"""Type variable representing the function under test's output type. -""" - -__all__: Sequence[str] = ( - "table_tester", - "no_aws_r_creds", - "no_aws_rw_creds", - "env_var_is_true", -) - - -def table_tester( - fn: Callable[[In], Out], - i_o_pairs: Iterable[Tuple[In, Out]], - equality: Callable[[Out, Out], bool] = lambda a, b: a == b, -) -> None: - """Convenience function to apply a function against a series of input & expected output pairs. - This function `assert`s that the function applied to each input results in the associated output - value, where equality is checked by the :param:`equality` function, which defaults to Python's `==`. - """ - for i, (inp, expected) in enumerate(i_o_pairs): - msg_part = f"Failed on test pair # {i + 1}:\nINPUT: {inp}\nEXPECTED: {expected}\n" - try: - actual = fn(inp) - except Exception: # pylint: disable=broad-except - print(msg_part) - raise - assert equality(actual, expected), msg_part + f"ACTUAL: {actual}" - - -@lru_cache(1) -def no_aws_r_creds() -> bool: - """True if we don't have the read AWS access credentials to run tests. False means we do. - - Useful in a `@pytest.mark.skipif(condition=no_aws_r_creds(), reason="No AWS read credentials")` - marker on a `test_` unittest function. - """ - return _no_aws_creds(write_check=False) - - -@lru_cache(1) -def no_aws_rw_creds() -> bool: - """True if we don't have the read+write AWS access credentials to run tests. False means we do. - - Useful in a `@pytest.mark.skipif(condition=no_aws_rw_creds(), reason="No AWS read/write credentials")` - marker on a `test_` unittest function. - """ - return _no_aws_creds(write_check=True) - - -def _no_aws_creds(*, write_check: bool) -> bool: - try: - p = parse_attachment_url(f"s3://{ml_infra_config().s3_bucket}/testing/_keep_do_not_delete") - s3_client = sync_storage_client() - if not _exists(s3_client, p): - return True - - with NamedTemporaryFile() as f: - f.close() - # test read - with open(f.name, "wb") as wb: - s3_client.download_fileobj( - Bucket=p.bucket, - Key=p.key, - Fileobj=wb, - ) - if write_check: - # test write - with open(f.name, "rb") as rb: - s3_client.upload_fileobj( - Fileobj=rb, - Bucket=p.bucket, - Key=p.key, - ) - except Exception: # pylint: disable=broad-except - return True - else: - return False - - -def _exists(s3_client, p): - try: - # https://stackoverflow.com/questions/33842944/check-if-a-key-exists-in-a-bucket-in-s3-using-boto3 - s3_client.head_object(Bucket=p.bucket, Key=p.key) - except Exception as e: # type: ignore - try: - # pylint: disable=no-member - error_code = e.response["Error"]["Code"].strip() # type: ignore - if error_code in ("404", "NoSuchKey"): - return False - except (NameError, KeyError): - pass - raise e - else: - return True - - -def env_var_is_true(env_var_name: str) -> bool: - """Return true if the environment variable is currently set to a known truth value. - - True if the :param:`env_var_name` environment variable is present and contains a truth value. - The **only** accepted truth values are, case-insensitive: - - 'y' - - 'yes' - - 'true' - - - All other values are considered false. - Additionally, an unset environment variable will result in this function evaluating to false. - """ - if len(env_var_name) == 0: - raise ValueError("Need non-empty environment variable name!") - - try: - x: Optional[str] = os.environ.get(env_var_name, None) - if x is None: - return False - x = x.lower().strip() - return x in ("y", "true", "yes") - except Exception: # pylint: disable=broad-except - return False - - -def is_linux() -> bool: - return "Linux" in platform.platform() diff --git a/server/llm_engine_server/db/base.py b/server/llm_engine_server/db/base.py deleted file mode 100644 index 68f4a7bf..00000000 --- a/server/llm_engine_server/db/base.py +++ /dev/null @@ -1,139 +0,0 @@ -import asyncio -import os -import sys -from typing import Iterator, Optional - -import sqlalchemy -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from sqlalchemy import create_engine -from sqlalchemy.ext.asyncio import async_scoped_session, async_sessionmaker, create_async_engine -from sqlalchemy.ext.declarative import declarative_base -from sqlalchemy.orm import sessionmaker -from sqlalchemy.pool import NullPool - -logger = make_logger(filename_wo_ext(__file__)) - - -def get_engine_url(env: Optional[str] = None, read_only: bool = True, sync: bool = True) -> str: - """Gets the URL of the Postgresql engine depending on the environment.""" - if os.getenv("ML_INFRA_DATABASE_URL"): - # In CircleCI environment, we set up a test in another container and specify the URL. - engine_url = os.getenv("ML_INFRA_DATABASE_URL") - else: - assert "pytest" in sys.modules, "Must specify ML_INFRA_DATABASE_URL or be in a testing env." - # If we are in a local testing environment, we can set up a test psql instance. - # pylint: disable=import-outside-toplevel - import testing.postgresql - - Postgresql = testing.postgresql.PostgresqlFactory( - cache_initialized_db=True, - ) - postgresql = Postgresql() - engine_url = postgresql.url() - - assert engine_url - - # For async postgres, we need to use an async dialect. - if not sync: - engine_url = engine_url.replace("postgresql://", "postgresql+asyncpg://") - return engine_url - - -# Try pool_pre_ping=True, see -# https://docs.sqlalchemy.org/en/14/core/engines.html -# ?highlight=create_engine#sqlalchemy.create_engine.params.pool_pre_ping -# tl;dr is hopefully it stops the psycopg errors from happening -# Another probably less impactful (ie it shouldn't increase latency by as much, -# but also shouldn't get rid of as many errors e.g. 95% compared to 99.9%) -# option is to try pool_recycle = something kinda short e.g. a minute -# pool_pre_ping=True seems to not increase latency by very much -# (I profiled 2.7 ms -> 3.3 ms on GET model_bundles/) -# but hopefully should completely eliminate -# any of the postgres connection errors we've been seeing. - -ml_infra_pg_engine = create_engine( - get_engine_url(read_only=False, sync=True), - echo=False, - future=True, - pool_pre_ping=True, -) -ml_infra_pg_engine_read_only = create_engine( - get_engine_url(read_only=True, sync=True), - echo=False, - future=True, - pool_pre_ping=True, -) -ml_infra_pg_engine_async = create_async_engine( - get_engine_url(read_only=False, sync=False), - echo=False, - future=True, - pool_pre_ping=True, -) -ml_infra_pg_engine_read_only_async = create_async_engine( - get_engine_url(read_only=True, sync=False), - echo=False, - future=True, - pool_pre_ping=True, - max_overflow=5, -) -ml_infra_pg_engine_async_null_pool = create_async_engine( - get_engine_url(read_only=False, sync=False), - echo=False, - future=True, - poolclass=NullPool, - pool_pre_ping=True, -) - -# Synchronous sessions (Session and SessionReadOnly) are fairly straightforward, and both -# can be used at any time. To use asynchronous sqlalchemy, use the SessionAsyncNullPool -# if you're running a synchronous program where concurrency of database connections is not -# super important (e.g. Celery workers that use long-standing connections, and Celery is currently -# synchronous). Use SessionAsync and SessionReadOnlyAsync in ASGI applications. -Session = sessionmaker(autocommit=False, autoflush=False, bind=ml_infra_pg_engine) -SessionReadOnly = sessionmaker(autocommit=False, autoflush=False, bind=ml_infra_pg_engine_read_only) -SessionAsync = async_scoped_session( - session_factory=async_sessionmaker( - autocommit=False, - autoflush=False, - bind=ml_infra_pg_engine_async, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, -) -SessionAsyncNullPool = async_scoped_session( - session_factory=async_sessionmaker( - autocommit=False, - autoflush=False, - bind=ml_infra_pg_engine_async_null_pool, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, -) -SessionReadOnlyAsync = async_scoped_session( - async_sessionmaker( - autocommit=False, - autoflush=False, - bind=ml_infra_pg_engine_read_only_async, - expire_on_commit=False, - ), - scopefunc=asyncio.current_task, -) -Base = declarative_base() - - -def get_session_iterator() -> Iterator[sqlalchemy.orm.Session]: - """Utility to return an iterator with an instantiated session in the ML Infra database.""" - session = Session() - try: - yield session - finally: - session.close() - - -def get_read_only_session_iterator() -> Iterator[sqlalchemy.orm.Session]: - """Utility to return an iterator with an instantiated session in the ML Infra database.""" - session = SessionReadOnly() - try: - yield session - finally: - session.close() diff --git a/server/llm_engine_server/db/migrations/alembic.ini b/server/llm_engine_server/db/migrations/alembic.ini deleted file mode 100644 index 574eb9cc..00000000 --- a/server/llm_engine_server/db/migrations/alembic.ini +++ /dev/null @@ -1,85 +0,0 @@ -# A generic, single database configuration. - -[alembic] -# path to migration scripts -script_location = alembic - -# template used to generate migration files -# file_template = %%(rev)s_%%(slug)s - -# timezone to use when rendering the date -# within the migration file as well as the filename. -# string value is passed to dateutil.tz.gettz() -# leave blank for localtime -# timezone = - -# max length of characters to apply to the -# "slug" field -# truncate_slug_length = 40 - -# set to 'true' to run the environment during -# the 'revision' command, regardless of autogenerate -# revision_environment = false - -# set to 'true' to allow .pyc and .pyo files without -# a source .py file to be detected as revisions in the -# versions/ directory -# sourceless = false - -# version location specification; this defaults -# to alembic/versions. When using multiple version -# directories, initial revisions must be specified with --version-path -# version_locations = %(here)s/bar %(here)s/bat alembic/versions - -# the output encoding used when revision files -# are written from script.py.mako -# output_encoding = utf-8 - -sqlalchemy.url = driver://user:pass@localhost/dbname - - -[post_write_hooks] -# post_write_hooks defines scripts or Python functions that are run -# on newly generated revision scripts. See the documentation for further -# detail and examples - -# format using "black" - use the console_scripts runner, against the "black" entrypoint -# hooks=black -# black.type=console_scripts -# black.entrypoint=black -# black.options=-l 79 - -# Logging configuration -[loggers] -keys = root,sqlalchemy,alembic - -[handlers] -keys = console - -[formatters] -keys = generic - -[logger_root] -level = WARN -handlers = console -qualname = - -[logger_sqlalchemy] -level = DEBUG -handlers = -qualname = sqlalchemy.engine - -[logger_alembic] -level = INFO -handlers = -qualname = alembic - -[handler_console] -class = StreamHandler -args = (sys.stderr,) -level = NOTSET -formatter = generic - -[formatter_generic] -format = %(levelname)-5.5s [%(name)s] %(message)s -datefmt = %H:%M:%S diff --git a/server/llm_engine_server/db/migrations/alembic/README b/server/llm_engine_server/db/migrations/alembic/README deleted file mode 100644 index 98e4f9c4..00000000 --- a/server/llm_engine_server/db/migrations/alembic/README +++ /dev/null @@ -1 +0,0 @@ -Generic single-database configuration. \ No newline at end of file diff --git a/server/llm_engine_server/db/ml_infra_pg.py b/server/llm_engine_server/db/ml_infra_pg.py deleted file mode 100644 index 0a2d4852..00000000 --- a/server/llm_engine_server/db/ml_infra_pg.py +++ /dev/null @@ -1,10 +0,0 @@ -from .base import Base, ml_infra_pg_engine - -# we need to import the following for sqlalchemy -# pylint: disable=unused-import -from .models.llm_engine import Bundle, Endpoint # noqa -from .models.model import Model, ModelArtifact, ModelVersion # noqa -from .models.train import Execution, Experiment, Job, Snapshot # noqa - -# run this file to create the db models imported -Base.metadata.create_all(ml_infra_pg_engine) diff --git a/server/llm_engine_server/domain/entities/common_types.py b/server/llm_engine_server/domain/entities/common_types.py deleted file mode 100644 index 899a2973..00000000 --- a/server/llm_engine_server/domain/entities/common_types.py +++ /dev/null @@ -1,4 +0,0 @@ -from typing import Union - -CpuSpecificationType = Union[str, int, float] -StorageSpecificationType = Union[str, int, float] # TODO(phil): we can make this more specific. diff --git a/server/llm_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py b/server/llm_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py deleted file mode 100644 index 02a14990..00000000 --- a/server/llm_engine_server/domain/entities/docker_image_batch_job_bundle_entity.py +++ /dev/null @@ -1,27 +0,0 @@ -import datetime -from typing import Dict, List, Optional - -from llm_engine_server.domain.entities import GpuType -from llm_engine_server.domain.entities.owned_entity import OwnedEntity - - -class DockerImageBatchJobBundle(OwnedEntity): - id: str - name: str - created_by: str - created_at: datetime.datetime - owner: str - image_repository: str - image_tag: str - command: List[str] - env: Dict[str, str] - mount_location: Optional[str] - cpus: Optional[str] - memory: Optional[str] - storage: Optional[str] - gpus: Optional[int] - gpu_type: Optional[GpuType] - public: Optional[bool] - - class Config: - orm_mode = True diff --git a/server/llm_engine_server/domain/entities/gpu_type.py b/server/llm_engine_server/domain/entities/gpu_type.py deleted file mode 100644 index 99cfd1b4..00000000 --- a/server/llm_engine_server/domain/entities/gpu_type.py +++ /dev/null @@ -1,9 +0,0 @@ -from enum import Enum - - -class GpuType(str, Enum): - """Lists allowed GPU types for LLMEngine.""" - - NVIDIA_TESLA_T4 = "nvidia-tesla-t4" - NVIDIA_AMPERE_A10 = "nvidia-ampere-a10" - NVIDIA_AMPERE_A100 = "nvidia-a100" diff --git a/server/llm_engine_server/domain/entities/llm_fine_tune_job_entity.py b/server/llm_engine_server/domain/entities/llm_fine_tune_job_entity.py deleted file mode 100644 index 487483ae..00000000 --- a/server/llm_engine_server/domain/entities/llm_fine_tune_job_entity.py +++ /dev/null @@ -1,14 +0,0 @@ -from typing import Any, Dict, List - -from pydantic import BaseModel - - -class LLMFineTuneJobTemplate(BaseModel): - docker_image_batch_job_bundle_id: str - launch_bundle_config: Dict[str, Any] - launch_endpoint_config: Dict[str, Any] - default_hparams: Dict[str, Any] - required_params: List[str] - - class Config: - orm_mode = True diff --git a/server/llm_engine_server/domain/exceptions.py b/server/llm_engine_server/domain/exceptions.py deleted file mode 100644 index 4763690a..00000000 --- a/server/llm_engine_server/domain/exceptions.py +++ /dev/null @@ -1,83 +0,0 @@ -from llm_engine_server.core.domain_exceptions import DomainException - - -class ExistingEndpointOperationInProgressException(DomainException): - """ - Thrown when a user tries to edit an endpoint that has an edit in progress - """ - - def __init__(self, message): - self.message = message - - -class EndpointDeleteFailedException(DomainException): - """ - Thrown if the server failed to delete an endpoint for whatever reason. Indicates a bug serverside - """ - - -class EndpointUnsupportedInferenceTypeException(DomainException): - """ - Thrown if the requested inference type is unsupported by the endpoint. - """ - - -class EndpointResourceInvalidRequestException(DomainException): - """ - Thrown if the endpoint resource requests are invalid. - """ - - -class EndpointInfraStateNotFound(DomainException): - """ - Thrown if the endpoint infra_state field is expected to be not None but found to be None. - """ - - -class EndpointResourceInfraException(DomainException): - """ - Thrown if the endpoint resource request passes validation, but failed for unhandled reasons. - This corresponds to a 503 error and requires investigation by the LLMEngine team. - """ - - -class EndpointLabelsException(DomainException): - """ - Thrown if the endpoint required labels are missing or wrong. - """ - - -class TooManyRequestsException(DomainException): - """ - Thrown if an endpoint returns a 429 exception for too many requests. - """ - - -class CorruptRecordInfraStateException(DomainException): - """ - Thrown if the data from existing state (i.e. the db, k8s, etc.) is somehow uninterpretable - by the code. This can occur if the state isn't being written to correctly, if we've missed - a migration somewhere, etc. - """ - - -class UpstreamServiceError(DomainException): - """ - Thrown to relay an upstream HTTP service error to the user. - """ - - def __init__(self, status_code: int, content: bytes): - self.status_code = status_code - self.content = content - - -class LLMFineTuningMethodNotImplementedException(DomainException): - """ - Thrown if the requested fine-tuning model/method pair is not implemented. - """ - - -class InvalidRequestException(DomainException): - """ - Thrown if the user request is invalid. - """ diff --git a/server/llm_engine_server/domain/gateways/monitoring_metrics_gateway.py b/server/llm_engine_server/domain/gateways/monitoring_metrics_gateway.py deleted file mode 100644 index 33ff03c2..00000000 --- a/server/llm_engine_server/domain/gateways/monitoring_metrics_gateway.py +++ /dev/null @@ -1,46 +0,0 @@ -""" -For emitting external monitoring metrics to some sort of store e.g. datadog -Currently distinct from something emitting to a Metrics Store - -Used to calculate proportion of successful/unsuccessful requests, differentiates between -docker build vs other errors -""" - -from abc import ABC, abstractmethod - - -class MonitoringMetricsGateway(ABC): - @abstractmethod - def emit_attempted_build_metric(self): - """ - Service builder attempted metric - - """ - - @abstractmethod - def emit_successful_build_metric(self): - """ - Service builder succeeded metric - - """ - - @abstractmethod - def emit_docker_failed_build_metric(self): - """ - Service builder docker build failed metric - - """ - - @abstractmethod - def emit_database_cache_hit_metric(self): - """ - Successful database cache metric - - """ - - @abstractmethod - def emit_database_cache_miss_metric(self): - """ - Missed database cache metric - - """ diff --git a/server/llm_engine_server/domain/services/llm_fine_tuning_service.py b/server/llm_engine_server/domain/services/llm_fine_tuning_service.py deleted file mode 100644 index 0f71592e..00000000 --- a/server/llm_engine_server/domain/services/llm_fine_tuning_service.py +++ /dev/null @@ -1,30 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Dict - - -class LLMFineTuningService(ABC): - @abstractmethod - async def create_fine_tune_job( - self, - created_by: str, - owner: str, - training_file: str, - validation_file: str, - model_name: str, - base_model: str, - fine_tuning_method: str, - hyperparameters: Dict[str, str], - ): - pass - - @abstractmethod - async def get_fine_tune_job(self, owner: str, fine_tune_id: str): - pass - - @abstractmethod - async def list_fine_tune_jobs(self, owner: str): - pass - - @abstractmethod - async def cancel_fine_tune_job(self, owner: str, fine_tune_id: str): - pass diff --git a/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py b/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py deleted file mode 100644 index 34ef2172..00000000 --- a/server/llm_engine_server/domain/use_cases/llm_fine_tuning_use_cases.py +++ /dev/null @@ -1,82 +0,0 @@ -from llm_engine_server.common.dtos.llms import ( - CancelFineTuneJobResponse, - CreateFineTuneJobRequest, - CreateFineTuneJobResponse, - GetFineTuneJobResponse, - ListFineTuneJobResponse, -) -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ObjectNotFoundException -from llm_engine_server.infra.services import DockerImageBatchJobLLMFineTuningService - - -class CreateFineTuneJobV1UseCase: - def __init__(self, llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService): - self.llm_fine_tuning_service = llm_fine_tuning_service - - async def execute( - self, user: User, request: CreateFineTuneJobRequest - ) -> CreateFineTuneJobResponse: - fine_tune_id = await self.llm_fine_tuning_service.create_fine_tune_job( - created_by=user.user_id, - owner=user.team_id, - training_file=request.training_file, - validation_file=request.validation_file, - model_name=request.model_name, - base_model=request.base_model, - fine_tuning_method=request.fine_tuning_method, - hyperparameters=request.hyperparameters, - ) - return CreateFineTuneJobResponse( - fine_tune_id=fine_tune_id, - ) - - -class GetFineTuneJobV1UseCase: - def __init__(self, llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService): - self.llm_fine_tuning_service = llm_fine_tuning_service - - async def execute(self, user: User, fine_tune_id: str) -> GetFineTuneJobResponse: - di_batch_job = await self.llm_fine_tuning_service.get_fine_tune_job( - owner=user.team_id, - fine_tune_id=fine_tune_id, - ) - if di_batch_job is None: - raise ObjectNotFoundException - return GetFineTuneJobResponse( - fine_tune_id=di_batch_job.id, - status=di_batch_job.status, - ) - - -class ListFineTuneJobV1UseCase: - def __init__(self, llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService): - self.llm_fine_tuning_service = llm_fine_tuning_service - - async def execute(self, user: User) -> ListFineTuneJobResponse: - di_batch_jobs = await self.llm_fine_tuning_service.list_fine_tune_jobs( - owner=user.team_id, - ) - return ListFineTuneJobResponse( - jobs=[ - GetFineTuneJobResponse( - fine_tune_id=job.id, - status=job.status, - ) - for job in di_batch_jobs - ] - ) - - -class CancelFineTuneJobV1UseCase: - def __init__(self, llm_fine_tuning_service: DockerImageBatchJobLLMFineTuningService): - self.llm_fine_tuning_service = llm_fine_tuning_service - - async def execute(self, user: User, fine_tune_id: str) -> CancelFineTuneJobResponse: - success = await self.llm_fine_tuning_service.cancel_fine_tune_job( - owner=user.team_id, - fine_tune_id=fine_tune_id, - ) - return CancelFineTuneJobResponse( - success=success, - ) diff --git a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py b/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py deleted file mode 100644 index 482a4519..00000000 --- a/server/llm_engine_server/domain/use_cases/llm_model_endpoint_use_cases.py +++ /dev/null @@ -1,811 +0,0 @@ -import json -from dataclasses import asdict -from typing import Any, AsyncIterable, Dict, Optional - -from llm_engine_server.common.dtos.llms import ( - CompletionOutput, - CompletionStreamOutput, - CompletionStreamV1Request, - CompletionStreamV1Response, - CompletionSyncV1Request, - CompletionSyncV1Response, - CreateLLMModelEndpointV1Request, - CreateLLMModelEndpointV1Response, - GetLLMModelEndpointV1Response, - ListLLMModelEndpointsV1Response, -) -from llm_engine_server.common.dtos.model_bundles import CreateModelBundleV2Request -from llm_engine_server.common.dtos.model_endpoints import ModelEndpointOrderBy -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request, TaskStatus -from llm_engine_server.common.resource_limits import validate_resource_requests -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.authorization.scale_authorization_module import ( - ScaleAuthorizationModule, -) -from llm_engine_server.domain.entities import ( - LLMInferenceFramework, - LLMMetadata, - LLMSource, - ModelBundle, - ModelBundleFlavorType, - ModelEndpoint, - ModelEndpointType, - Quantization, - RunnableImageFlavor, - StreamingEnhancedRunnableImageFlavor, -) -from llm_engine_server.domain.exceptions import ( - EndpointLabelsException, - EndpointUnsupportedInferenceTypeException, -) -from llm_engine_server.domain.repositories import ModelBundleRepository -from llm_engine_server.domain.services import LLMModelEndpointService, ModelEndpointService - -from .model_bundle_use_cases import CreateModelBundleV2UseCase -from .model_endpoint_use_cases import ( - _handle_post_inference_hooks, - model_endpoint_entity_to_get_model_endpoint_response, - validate_deployment_resources, - validate_post_inference_hooks, -) - -logger = make_logger(filename_wo_ext(__name__)) - -_SUPPORTED_MODEL_NAMES = { - LLMInferenceFramework.DEEPSPEED: { - "mpt-7b": "mosaicml/mpt-7b", - "mpt-7b-instruct": "mosaicml/mpt-7b-instruct", - "gpt-j-6b": "EleutherAI/gpt-j-6b", - "gpt-j-6b-zh-en": "EleutherAI/gpt-j-6b", - "gpt4all-j": "nomic-ai/gpt4all-j", - "dolly-v2-12b": "databricks/dolly-v2-12b", - "stablelm-tuned-7b": "StabilityAI/stablelm-tuned-alpha-7b", - "flan-t5-xxl": "google/flan-t5-xxl", - "llama-7b": "decapoda-research/llama-7b-hf", - "vicuna-13b": "eachadea/vicuna-13b-1.1", - }, - LLMInferenceFramework.TEXT_GENERATION_INFERENCE: { - "mpt-7b": "mosaicml/mpt-7b", - "mpt-7b-instruct": "mosaicml/mpt-7b-instruct", - "flan-t5-xxl": "google/flan-t5-xxl", - "llama-7b": "decapoda-research/llama-7b-hf", - "falcon-7b": "tiiuae/falcon-7b", - "falcon-7b-instruct": "tiiuae/falcon-7b-instruct", - "falcon-40b": "tiiuae/falcon-40b", - "falcon-40b-instruct": "tiiuae/falcon-40b-instruct", - }, -} - - -def _model_endpoint_entity_to_get_llm_model_endpoint_response( - model_endpoint: ModelEndpoint, -) -> GetLLMModelEndpointV1Response: - if model_endpoint.record.metadata is None or "_llm" not in model_endpoint.record.metadata: - raise ObjectHasInvalidValueException( - f"Can't translate model entity to response, endpoint {model_endpoint.record.id} does not have LLM metadata." - ) - llm_metadata = model_endpoint.record.metadata.get("_llm", {}) - response = GetLLMModelEndpointV1Response( - id=model_endpoint.record.id, - name=model_endpoint.record.name, - model_name=llm_metadata["model_name"], - source=llm_metadata["source"], - inference_framework=llm_metadata["inference_framework"], - inference_framework_image_tag=llm_metadata["inference_framework_image_tag"], - num_shards=llm_metadata["num_shards"], - quantize=llm_metadata.get("quantize"), - spec=model_endpoint_entity_to_get_model_endpoint_response(model_endpoint), - ) - return response - - -def validate_model_name(model_name: str, inference_framework: LLMInferenceFramework) -> None: - if model_name not in _SUPPORTED_MODEL_NAMES[inference_framework]: - raise ObjectHasInvalidValueException( - f"Model name {model_name} is not supported for inference framework {inference_framework}." - ) - - -def validate_num_shards( - num_shards: int, inference_framework: LLMInferenceFramework, gpus: int -) -> None: - if inference_framework == LLMInferenceFramework.DEEPSPEED: - if num_shards <= 1: - raise ObjectHasInvalidValueException("DeepSpeed requires more than 1 GPU.") - if num_shards != gpus: - raise ObjectHasInvalidValueException( - f"DeepSpeed requires num shard {num_shards} to be the same as number of GPUs {gpus}." - ) - - -class CreateLLMModelEndpointV1UseCase: - def __init__( - self, - create_model_bundle_use_case: CreateModelBundleV2UseCase, - model_bundle_repository: ModelBundleRepository, - model_endpoint_service: ModelEndpointService, - ): - self.authz_module = ScaleAuthorizationModule() - self.create_model_bundle_use_case = create_model_bundle_use_case - self.model_bundle_repository = model_bundle_repository - self.model_endpoint_service = model_endpoint_service - - async def create_model_bundle( - self, - user: User, - endpoint_name: str, - model_name: str, - source: LLMSource, - framework: LLMInferenceFramework, - framework_image_tag: str, - endpoint_type: ModelEndpointType, - num_shards: int, - quantize: Optional[Quantization], - checkpoint_path: Optional[str], - ) -> ModelBundle: - if source == LLMSource.HUGGING_FACE: - if framework == LLMInferenceFramework.DEEPSPEED: - bundle_id = await self.create_deepspeed_bundle( - user, - model_name, - framework_image_tag, - endpoint_type, - endpoint_name, - ) - elif framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: - bundle_id = await self.create_text_generation_inference_bundle( - user, - model_name, - framework_image_tag, - endpoint_name, - num_shards, - quantize, - checkpoint_path, - ) - else: - raise ObjectHasInvalidValueException( - f"Framework {framework} is not supported for source {source}." - ) - else: - raise ObjectHasInvalidValueException(f"Source {source} is not supported.") - - model_bundle = await self.model_bundle_repository.get_model_bundle(bundle_id) - if model_bundle is None: - raise ObjectNotFoundException(f"Model bundle {bundle_id} was not found after creation.") - return model_bundle - - async def create_text_generation_inference_bundle( - self, - user: User, - model_name: str, - framework_image_tag: str, - endpoint_unique_name: str, - num_shards: int, - quantize: Optional[Quantization], - checkpoint_path: Optional[str], - ): - command = [] - if checkpoint_path is not None: - if checkpoint_path.startswith("s3://"): - command = ["bash", "launch_s3_model.sh", checkpoint_path, str(num_shards)] - if quantize: - command = command + [f"'--quantize {str(quantize)}'"] - else: - raise ObjectHasInvalidValueException( - f"Not able to load checkpoint path {checkpoint_path}." - ) - else: - hf_model_name = _SUPPORTED_MODEL_NAMES[LLMInferenceFramework.TEXT_GENERATION_INFERENCE][ - model_name - ] - - command = [ - "text-generation-launcher", - "--model-id", - hf_model_name, - "--num-shard", - str(num_shards), - "--port", - "5005", - "--hostname", - "::", - ] - if quantize: - command = command + ["--quantize", str(quantize)] - - return ( - await self.create_model_bundle_use_case.execute( - user, - CreateModelBundleV2Request( - name=endpoint_unique_name, - schema_location="TBA", - flavor=StreamingEnhancedRunnableImageFlavor( - flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, - repository="text-generation-inference", # TODO: let user choose repo - tag=framework_image_tag, - command=command, - streaming_command=command, - protocol="http", - readiness_initial_delay_seconds=60, - healthcheck_route="/health", - predict_route="/generate", - streaming_predict_route="/generate_stream", - env={}, - ), - metadata={}, - ), - ) - ).model_bundle_id - - async def create_deepspeed_bundle( - self, - user: User, - model_name: str, - framework_image_tag: str, - endpoint_type: ModelEndpointType, - endpoint_unique_name: str, - ): - if endpoint_type == ModelEndpointType.STREAMING: - command = [ - "dumb-init", - "--", - "ddtrace-run", - "run-streamer", - "--http", - "production_threads", - "--concurrency", - "1", - "--config", - "/install/spellbook/inference/service--spellbook_streaming_inference.yaml", - ] - return ( - await self.create_model_bundle_use_case.execute( - user, - CreateModelBundleV2Request( - name=endpoint_unique_name, - schema_location="TBA", - flavor=StreamingEnhancedRunnableImageFlavor( - flavor=ModelBundleFlavorType.STREAMING_ENHANCED_RUNNABLE_IMAGE, - repository="instant-llm", # TODO: let user choose repo - tag=framework_image_tag, - command=command, - streaming_command=command, - env={ - "MODEL_NAME": model_name, - }, - protocol="http", - readiness_initial_delay_seconds=60, - ), - metadata={}, - ), - ) - ).model_bundle_id - else: - return ( - await self.create_model_bundle_use_case.execute( - user, - CreateModelBundleV2Request( - name=endpoint_unique_name, - schema_location="TBA", - flavor=RunnableImageFlavor( - flavor=ModelBundleFlavorType.RUNNABLE_IMAGE, - repository="instant-llm", - tag=framework_image_tag, - command=[ - "dumb-init", - "--", - "ddtrace-run", - "run-service", - "--http", - "production_threads", - "--concurrency", - "1", - "--config", - "/install/spellbook/inference/service--spellbook_inference.yaml", - ], - env={ - "MODEL_NAME": model_name, - }, - protocol="http", - readiness_initial_delay_seconds=1800, - ), - metadata={}, - ), - ) - ).model_bundle_id - - async def execute( - self, user: User, request: CreateLLMModelEndpointV1Request - ) -> CreateLLMModelEndpointV1Response: - validate_deployment_resources( - min_workers=request.min_workers, - max_workers=request.max_workers, - endpoint_type=request.endpoint_type, - ) - if request.labels is None: - raise EndpointLabelsException("Endpoint labels cannot be None!") - validate_post_inference_hooks(user, request.post_inference_hooks) - validate_model_name(request.model_name, request.inference_framework) - validate_num_shards(request.num_shards, request.inference_framework, request.gpus) - - if request.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: - if request.endpoint_type != ModelEndpointType.STREAMING: - raise ObjectHasInvalidValueException( - f"Creating endpoint type {str(request.endpoint_type)} is not allowed. Can only create streaming endpoints for text-generation-inference." - ) - - bundle = await self.create_model_bundle( - user, - endpoint_name=request.name, - model_name=request.model_name, - source=request.source, - framework=request.inference_framework, - framework_image_tag=request.inference_framework_image_tag, - endpoint_type=request.endpoint_type, - num_shards=request.num_shards, - quantize=request.quantize, - checkpoint_path=request.checkpoint_path, - ) - validate_resource_requests( - bundle=bundle, - cpus=request.cpus, - memory=request.memory, - storage=request.storage, - gpus=request.gpus, - gpu_type=request.gpu_type, - ) - - prewarm = request.prewarm - if prewarm is None: - prewarm = True - - high_priority = request.high_priority - if high_priority is None: - high_priority = False - - aws_role = self.authz_module.get_aws_role_for_user(user) - results_s3_bucket = self.authz_module.get_s3_bucket_for_user(user) - - request.metadata["_llm"] = asdict( - LLMMetadata( - model_name=request.model_name, - source=request.source, - inference_framework=request.inference_framework, - inference_framework_image_tag=request.inference_framework_image_tag, - num_shards=request.num_shards, - quantize=request.quantize, - ) - ) - - model_endpoint_record = await self.model_endpoint_service.create_model_endpoint( - name=request.name, - created_by=user.user_id, - model_bundle_id=bundle.id, - endpoint_type=request.endpoint_type, - metadata=request.metadata, - post_inference_hooks=request.post_inference_hooks, - child_fn_info=None, - cpus=request.cpus, - gpus=request.gpus, - memory=request.memory, - gpu_type=request.gpu_type, - storage=request.storage, - optimize_costs=bool(request.optimize_costs), - min_workers=request.min_workers, - max_workers=request.max_workers, - per_worker=request.per_worker, - labels=request.labels, - aws_role=aws_role, - results_s3_bucket=results_s3_bucket, - prewarm=prewarm, - high_priority=high_priority, - owner=user.team_id, - default_callback_url=request.default_callback_url, - default_callback_auth=request.default_callback_auth, - public_inference=request.public_inference, - ) - _handle_post_inference_hooks( - created_by=user.user_id, - name=request.name, - post_inference_hooks=request.post_inference_hooks, - ) - - return CreateLLMModelEndpointV1Response( - endpoint_creation_task_id=model_endpoint_record.creation_task_id # type: ignore - ) - - -class ListLLMModelEndpointsV1UseCase: - """ - Use case for listing all LLM Model Endpoint of a given user and model endpoint name. - Also include public_inference LLM endpoints. - """ - - def __init__(self, llm_model_endpoint_service: LLMModelEndpointService): - self.llm_model_endpoint_service = llm_model_endpoint_service - - async def execute( - self, user: User, name: Optional[str], order_by: Optional[ModelEndpointOrderBy] - ) -> ListLLMModelEndpointsV1Response: - """ - Runs the use case to list all Model Endpoints owned by the user with the given name. - - Args: - user: The owner of the model endpoint(s). - name: The name of the Model Endpoint(s). - order_by: An optional argument to specify the output ordering of the model endpoints. - - Returns: - A response object that contains the model endpoints. - """ - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=name, order_by=order_by - ) - return ListLLMModelEndpointsV1Response( - model_endpoints=[ - _model_endpoint_entity_to_get_llm_model_endpoint_response(m) - for m in model_endpoints - ] - ) - - -class GetLLMModelEndpointByNameV1UseCase: - """ - Use case for getting an LLM Model Endpoint of a given user by name. - """ - - def __init__(self, llm_model_endpoint_service: LLMModelEndpointService): - self.llm_model_endpoint_service = llm_model_endpoint_service - self.authz_module = ScaleAuthorizationModule() - - async def execute(self, user: User, model_endpoint_name: str) -> GetLLMModelEndpointV1Response: - """ - Runs the use case to get the LLM endpoint with the given name. - - Args: - user: The owner of the model endpoint. - model_endpoint_name: The name of the model endpoint. - - Returns: - A response object that contains the model endpoint. - - Raises: - ObjectNotFoundException: If a model endpoint with the given name could not be found. - ObjectNotAuthorizedException: If the owner does not own the model endpoint. - """ - model_endpoint = await self.llm_model_endpoint_service.get_llm_model_endpoint( - model_endpoint_name - ) - if not model_endpoint: - raise ObjectNotFoundException - if not self.authz_module.check_access_read_owned_entity( - user, model_endpoint.record - ) and not self.authz_module.check_endpoint_public_inference_for_user( - user, model_endpoint.record - ): - raise ObjectNotAuthorizedException - return _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) - - -class DeleteLLMModelEndpointByIdV1UseCase: - pass - - -class CompletionSyncV1UseCase: - """ - Use case for running a prompt completion on an LLM endpoint. - """ - - def __init__( - self, - model_endpoint_service: ModelEndpointService, - llm_model_endpoint_service: LLMModelEndpointService, - ): - self.model_endpoint_service = model_endpoint_service - self.llm_model_endpoint_service = llm_model_endpoint_service - self.authz_module = ScaleAuthorizationModule() - - def model_output_to_completion_output( - self, - model_output: Dict[str, Any], - model_endpoint: ModelEndpoint, - ) -> CompletionOutput: - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) - - if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: - completion_token_count = len(model_output["token_probs"]["tokens"]) - return CompletionOutput( - text=model_output["text"], - num_completion_tokens=completion_token_count, - ) - elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: - return CompletionOutput( - text=model_output["generated_text"], - # len(model_output["details"]["prefill"]) does not return the correct value reliably - num_completion_tokens=model_output["details"]["generated_tokens"], - ) - else: - raise EndpointUnsupportedInferenceTypeException( - f"Unsupported inference framework {model_content.inference_framework}" - ) - - async def execute( - self, user: User, model_endpoint_name: str, request: CompletionSyncV1Request - ) -> CompletionSyncV1Response: - """ - Runs the use case to create a sync inference task. - - Args: - user: The user who is creating the sync inference task. - model_endpoint_name: The name of the model endpoint for the task. - request: The body of the request to forward to the endpoint. - - Returns: - A response object that contains the status and result of the task. - - Raises: - ObjectNotFoundException: If a model endpoint with the given name could not be found. - ObjectNotAuthorizedException: If the owner does not own the model endpoint. - """ - - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None - ) - - if len(model_endpoints) == 0: - raise ObjectNotFoundException - - if len(model_endpoints) > 1: - raise ObjectHasInvalidValueException( - f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" - ) - - model_endpoint = model_endpoints[0] - - if not self.authz_module.check_access_read_owned_entity( - user, model_endpoint.record - ) and not self.authz_module.check_endpoint_public_inference_for_user( - user, model_endpoint.record - ): - raise ObjectNotAuthorizedException - - if ( - model_endpoint.record.endpoint_type is not ModelEndpointType.SYNC - and model_endpoint.record.endpoint_type is not ModelEndpointType.STREAMING - ): - raise EndpointUnsupportedInferenceTypeException( - f"Endpoint {model_endpoint_name} does not serve sync requests." - ) - - inference_gateway = self.model_endpoint_service.get_sync_model_endpoint_inference_gateway() - endpoint_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) - if endpoint_content.inference_framework == LLMInferenceFramework.DEEPSPEED: - args: Any = { - "prompts": request.prompts, - "token_probs": True, - "generate_kwargs": { - "do_sample": True, - "temperature": request.temperature, - "max_new_tokens": request.max_new_tokens, - }, - "serialize_results_as_string": False, - } - - inference_request = EndpointPredictV1Request(args=args) - predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request - ) - - if predict_result.status == TaskStatus.SUCCESS and predict_result.result is not None: - return CompletionSyncV1Response( - status=predict_result.status, - outputs=[ - self.model_output_to_completion_output(result, model_endpoint) - for result in predict_result.result["result"] - ], - ) - else: - return CompletionSyncV1Response( - status=predict_result.status, - outputs=[], - traceback=predict_result.traceback, - ) - elif ( - endpoint_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE - ): - outputs = [] - - for prompt in request.prompts: - tgi_args: Any = { - "inputs": prompt, - "parameters": { - "max_new_tokens": request.max_new_tokens, - "temperature": request.temperature, - "decoder_input_details": True, - }, - } - inference_request = EndpointPredictV1Request(args=tgi_args) - predict_result = await inference_gateway.predict( - topic=model_endpoint.record.destination, predict_request=inference_request - ) - - if predict_result.status != TaskStatus.SUCCESS or predict_result.result is None: - return CompletionSyncV1Response( - status=predict_result.status, - outputs=[], - traceback=predict_result.traceback, - ) - - outputs.append(json.loads(predict_result.result["result"])) - - return CompletionSyncV1Response( - status=predict_result.status, - outputs=[ - self.model_output_to_completion_output(output, model_endpoint) - for output in outputs - ], - ) - else: - raise EndpointUnsupportedInferenceTypeException( - f"Unsupported inference framework {endpoint_content.inference_framework}" - ) - - -class CompletionStreamV1UseCase: - """ - Use case for running a stream prompt completion on an LLM endpoint. - """ - - def __init__( - self, - model_endpoint_service: ModelEndpointService, - llm_model_endpoint_service: LLMModelEndpointService, - ): - self.model_endpoint_service = model_endpoint_service - self.llm_model_endpoint_service = llm_model_endpoint_service - self.authz_module = ScaleAuthorizationModule() - - async def execute( - self, user: User, model_endpoint_name: str, request: CompletionStreamV1Request - ) -> AsyncIterable[CompletionStreamV1Response]: - """ - Runs the use case to create a stream inference task. - - Args: - user: The user who is creating the stream inference task. - model_endpoint_name: The name of the model endpoint for the task. - request: The body of the request to forward to the endpoint. - - Returns: - A response object that contains the status and result of the task. - - Raises: - ObjectNotFoundException: If a model endpoint with the given name could not be found. - ObjectNotAuthorizedException: If the owner does not own the model endpoint. - """ - - model_endpoints = await self.llm_model_endpoint_service.list_llm_model_endpoints( - owner=user.team_id, name=model_endpoint_name, order_by=None - ) - - if len(model_endpoints) == 0: - raise ObjectNotFoundException - - if len(model_endpoints) > 1: - raise ObjectHasInvalidValueException( - f"Expected 1 LLM model endpoint for model name {model_endpoint_name}, got {len(model_endpoints)}" - ) - - model_endpoint = model_endpoints[0] - - if not self.authz_module.check_access_read_owned_entity( - user, model_endpoint.record - ) and not self.authz_module.check_endpoint_public_inference_for_user( - user, model_endpoint.record - ): - raise ObjectNotAuthorizedException - - if model_endpoint.record.endpoint_type != ModelEndpointType.STREAMING: - raise EndpointUnsupportedInferenceTypeException( - f"Endpoint {model_endpoint_name} is not a streaming endpoint." - ) - - inference_gateway = ( - self.model_endpoint_service.get_streaming_model_endpoint_inference_gateway() - ) - - model_content = _model_endpoint_entity_to_get_llm_model_endpoint_response(model_endpoint) - - args: Any = None - if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: - args = { - "prompts": [request.prompt], - "token_probs": True, - "generate_kwargs": { - "do_sample": True, - "temperature": request.temperature, - "max_new_tokens": request.max_new_tokens, - }, - "serialize_results_as_string": False, - } - elif model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE: - args = { - "inputs": request.prompt, - "parameters": { - "max_new_tokens": request.max_new_tokens, - "temperature": request.temperature, - }, - } - inference_request = EndpointPredictV1Request(args=args) - - predict_result = inference_gateway.streaming_predict( - topic=model_endpoint.record.destination, predict_request=inference_request - ) - - num_completion_tokens = 0 - async for res in predict_result: - result = res.result - if model_content.inference_framework == LLMInferenceFramework.DEEPSPEED: - if res.status == TaskStatus.SUCCESS and result is not None: - if "token" in result["result"]: - yield CompletionStreamV1Response( - status=res.status, - output=CompletionStreamOutput( - text=result["result"]["token"], - finished=False, - num_completion_tokens=None, - ), - ) - else: - completion_token_count = len( - result["result"]["response"][0]["token_probs"]["tokens"] - ) - yield CompletionStreamV1Response( - status=res.status, - output=CompletionStreamOutput( - text=result["result"]["response"][0]["text"], - finished=True, - num_completion_tokens=completion_token_count, - ), - ) - else: - yield CompletionStreamV1Response( - status=res.status, - output=None, - traceback=res.traceback, - ) - elif ( - model_content.inference_framework == LLMInferenceFramework.TEXT_GENERATION_INFERENCE - ): - if res.status == TaskStatus.SUCCESS and result is not None: - if result["result"].get("generated_text") is not None: - finished = True - else: - finished = False - - num_completion_tokens += 1 - - yield CompletionStreamV1Response( - status=res.status, - output=CompletionStreamOutput( - text=result["result"]["token"]["text"], - finished=finished, - num_completion_tokens=num_completion_tokens, - ), - ) - else: - yield CompletionStreamV1Response( - status=res.status, - output=None, - traceback=res.traceback, - ) - else: - raise EndpointUnsupportedInferenceTypeException( - f"Unsupported inference framework {model_content.inference_framework}" - ) diff --git a/server/llm_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml b/server/llm_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml deleted file mode 100644 index 84cc7e9c..00000000 --- a/server/llm_engine_server/inference/configs/service--forwarder-runnable-img-converted-from-artifact.yaml +++ /dev/null @@ -1,13 +0,0 @@ -forwarder: - model: - class_name: llm_engine.inference.forwarding.forwarding.LoadForwarder - args: - user_port: 5005 - user_hostname: "localhost" - use_grpc: false - predict_route: "/predict" - healthcheck_route: "/readyz" - batch_route: null - llm_engine_unwrap: false - serialize_results_as_string: false - wrap_response: false \ No newline at end of file diff --git a/server/llm_engine_server/inference/configs/service--forwarder.yaml b/server/llm_engine_server/inference/configs/service--forwarder.yaml deleted file mode 100644 index 3b7ff30e..00000000 --- a/server/llm_engine_server/inference/configs/service--forwarder.yaml +++ /dev/null @@ -1,13 +0,0 @@ -forwarder: - model: - class_name: llm_engine.inference.forwarding.forwarding.LoadForwarder - args: - user_port: 5005 - user_hostname: "localhost" - use_grpc: false - predict_route: "/predict" - healthcheck_route: "/readyz" - batch_route: null - llm_engine_unwrap: true - serialize_results_as_string: true - diff --git a/server/llm_engine_server/inference/configs/service--streaming_forwarder.yaml b/server/llm_engine_server/inference/configs/service--streaming_forwarder.yaml deleted file mode 100644 index 23f844e4..00000000 --- a/server/llm_engine_server/inference/configs/service--streaming_forwarder.yaml +++ /dev/null @@ -1,10 +0,0 @@ -forwarder: - class_name: llm_engine.inference.forwarding.forwarding.LoadStreamingForwarder - args: - user_port: 5005 - user_hostname: "localhost" - predict_route: "/stream" - healthcheck_route: "/readyz" - batch_route: null - llm_engine_unwrap: true - serialize_results_as_string: false diff --git a/server/llm_engine_server/inference/forwarding/http_forwarder.py b/server/llm_engine_server/inference/forwarding/http_forwarder.py deleted file mode 100644 index efac1752..00000000 --- a/server/llm_engine_server/inference/forwarding/http_forwarder.py +++ /dev/null @@ -1,167 +0,0 @@ -import argparse -import json -import os -import subprocess -from functools import lru_cache -from typing import Any, List - -import yaml -from fastapi import Depends, FastAPI -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.loggers import logger_name, make_logger -from llm_engine_server.inference.forwarding.forwarding import LoadForwarder, LoadStreamingForwarder -from sse_starlette.sse import EventSourceResponse - -logger = make_logger(logger_name()) -app = FastAPI() - - -def _set_value(config: dict, key_path: List[str], value: Any) -> None: - """ - Modifies config by setting the value at config[key_path[0]][key_path[1]]... to be `value`. - """ - key = key_path[0] - if len(key_path) == 1: - config[key] = value - else: - if key not in config: - config[key] = dict() - _set_value(config[key], key_path[1:], value) - - -def _substitute_config_overrides(config: dict, config_overrides: List[str]) -> None: - """ - Modifies config based on config_overrides. - - config_overrides should be a list of strings of the form `key=value`, - where `key` can be of the form `key1.key2` to denote a substitution for config[key1][key2] - (nesting can be arbitrarily deep). - """ - for override in config_overrides: - split = override.split("=") - if len(split) != 2: - raise ValueError(f"Config override {override} must contain exactly one =") - key_path, value = split - try: - _set_value(config, key_path.split("."), value) - except Exception as e: - raise ValueError(f"Error setting {key_path} to {value} in {config}") from e - - -def _load_named_config(config_uri, config_overrides=None): - with open(config_uri, "rt") as rt: - if config_uri.endswith(".json"): - return json.load(rt) - else: - c = yaml.safe_load(rt) - if config_overrides: - _substitute_config_overrides(c, config_overrides) - if len(c) == 1: - name = list(c.keys())[0] - c = c[name] - if "name" not in c: - c["name"] = name - return c - - -@app.get("/healthz") -@app.get("/readyz") -def healthcheck(): - return "OK" - - -def get_config(): - overrides = os.getenv("CONFIG_OVERRIDES") - config_overrides = None - if overrides is not None: - config_overrides = overrides.split(";") - return _load_named_config( - os.getenv("CONFIG_FILE"), - config_overrides, - ) - - -def get_forwarder_loader(): - config = get_config() - forwarder_loader = LoadForwarder(**config["sync"]) - return forwarder_loader - - -def get_streaming_forwarder_loader(): - config = get_config() - streaming_forwarder_loader = LoadStreamingForwarder(**config["stream"]) - return streaming_forwarder_loader - - -@lru_cache() -def load_forwarder(): - return get_forwarder_loader().load(None, None) - - -@lru_cache() -def load_streaming_forwarder(): - return get_streaming_forwarder_loader().load(None, None) - - -@app.post("/predict") -def predict(request: EndpointPredictV1Request, forwarder=Depends(load_forwarder)): - return forwarder(request.dict()) - - -@app.post("/stream") -async def stream(request: EndpointPredictV1Request, forwarder=Depends(load_streaming_forwarder)): - try: - payload = request.dict() - except Exception: - logger.error(f"Failed to decode payload from: {request}") - raise - else: - logger.debug(f"Received request: {payload}") - - # has internal error logging for each processing stage - responses = forwarder(payload) - - async def event_generator(): - for response in responses: - yield {"data": json.dumps(response)} - - return EventSourceResponse(event_generator()) - - -def entrypoint(): - parser = argparse.ArgumentParser() - parser.add_argument("--config", type=str, required=True) - parser.add_argument("--num-workers", type=int, required=True) - parser.add_argument("--host", type=str, default="[::]") - parser.add_argument("--port", type=int, default=5000) - parser.add_argument("--set", type=str, action="append") - - args = parser.parse_args() - - values = [f"CONFIG_FILE={args.config}"] - if args.set is not None: - values.append(f"CONFIG_OVERRIDES={';'.join(args.set)}") - envs = [] - for v in values: - envs.extend(["--env", v]) - - command = [ - "gunicorn", - "--bind", - f"{args.host}:{args.port}", - "--timeout", - "1200", - "--keep-alive", - "2", - "--worker-class", - "uvicorn.workers.UvicornWorker", - "--workers", - str(args.num_workers), - *envs, - "llm_engine_server.inference.forwarding.http_forwarder:app", - ] - subprocess.run(command) - - -if __name__ == "__main__": - entrypoint() diff --git a/server/llm_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py b/server/llm_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py deleted file mode 100644 index 8e7d3aa9..00000000 --- a/server/llm_engine_server/inference/infra/gateways/datadog_inference_monitoring_metrics_gateway.py +++ /dev/null @@ -1,12 +0,0 @@ -from datadog import statsd -from llm_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( - InferenceMonitoringMetricsGateway, -) - - -class DatadogInferenceMonitoringMetricsGateway(InferenceMonitoringMetricsGateway): - def emit_attempted_post_inference_hook(self, hook: str): - statsd.increment(f"scale_llm_engine_server.post_inference_hook.{hook}.attempt") - - def emit_successful_post_inference_hook(self, hook: str): - statsd.increment(f"scale_llm_engine_server.post_inference_hook.{hook}.success") diff --git a/server/llm_engine_server/inference/limits.conf b/server/llm_engine_server/inference/limits.conf deleted file mode 100644 index a22a6bc1..00000000 --- a/server/llm_engine_server/inference/limits.conf +++ /dev/null @@ -1,2 +0,0 @@ -llmengine hard nproc 2000 -llmengine soft nproc 1000 diff --git a/server/llm_engine_server/inference/post_inference_hooks.py b/server/llm_engine_server/inference/post_inference_hooks.py deleted file mode 100644 index 626bd1b5..00000000 --- a/server/llm_engine_server/inference/post_inference_hooks.py +++ /dev/null @@ -1,123 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional - -import requests -from llm_engine_server.common.constants import CALLBACK_POST_INFERENCE_HOOK -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import CallbackAuth, CallbackBasicAuth -from llm_engine_server.inference.common import _write_to_s3 -from llm_engine_server.inference.domain.gateways.inference_monitoring_metrics_gateway import ( - InferenceMonitoringMetricsGateway, -) -from tenacity import Retrying, stop_after_attempt, wait_exponential - -logger = make_logger(filename_wo_ext(__file__)) - - -def _upload_data(data: Any): - return _write_to_s3(data).get("result_url") - - -class PostInferenceHook(ABC): - def __init__( - self, - endpoint_name: str, - bundle_name: str, - user_id: str, - ): - self._endpoint_name = endpoint_name - self._bundle_name = bundle_name - self._user_id = user_id - - @abstractmethod - def handle( - self, - request_payload: EndpointPredictV1Request, - response: Dict[str, Any], - task_id: Optional[str], - ): - pass - - -class CallbackHook(PostInferenceHook): - def __init__( - self, - endpoint_name: str, - bundle_name: str, - user_id: str, - default_callback_url: Optional[str], - default_callback_auth: Optional[CallbackAuth], - ): - super().__init__(endpoint_name, bundle_name, user_id) - self._default_callback_url = default_callback_url - self._default_callback_auth = default_callback_auth - - def handle( - self, - request_payload: EndpointPredictV1Request, - response: Dict[str, Any], - task_id: Optional[str], - ): - callback_url = request_payload.callback_url - if not callback_url: - callback_url = self._default_callback_url - if not callback_url: - logger.warning("No callback URL specified for request.") - return - - response["task_id"] = task_id - auth = request_payload.callback_auth or self._default_callback_auth - if auth and isinstance(auth.__root__, CallbackBasicAuth): - auth_tuple = (auth.__root__.username, auth.__root__.password) - else: - auth_tuple = (self._user_id, "") - - for attempt in Retrying(stop=stop_after_attempt(3), wait=wait_exponential()): - with attempt: - res = requests.post(url=callback_url, json=response, auth=auth_tuple) - assert 200 <= res.status_code < 300 - - -class PostInferenceHooksHandler: - def __init__( - self, - endpoint_name: str, - bundle_name: str, - user_id: str, - default_callback_url: Optional[str], - default_callback_auth: Optional[CallbackAuth], - post_inference_hooks: Optional[List[str]], - monitoring_metrics_gateway: InferenceMonitoringMetricsGateway, - ): - self._monitoring_metrics_gateway = monitoring_metrics_gateway - self._hooks: Dict[str, PostInferenceHook] = {} - if post_inference_hooks: - for hook in post_inference_hooks: - # TODO: Ensure that this process gracefully handles errors in - # initializing each post-inference hook. - hook_lower = hook.lower() - if hook_lower == CALLBACK_POST_INFERENCE_HOOK: - self._hooks[hook_lower] = CallbackHook( - endpoint_name, - bundle_name, - user_id, - default_callback_url, - default_callback_auth, - ) - else: - raise ValueError(f"Hook {hook_lower} is currently not supported.") - - def handle( - self, - request_payload: EndpointPredictV1Request, - response: Dict[str, Any], - task_id: Optional[str] = None, - ): - for hook_name, hook in self._hooks.items(): - self._monitoring_metrics_gateway.emit_attempted_post_inference_hook(hook_name) - try: - hook.handle(request_payload, response, task_id) - self._monitoring_metrics_gateway.emit_successful_post_inference_hook(hook_name) - except Exception: - logger.exception(f"Hook {hook_name} failed.") diff --git a/server/llm_engine_server/inference/pytorch_or_tf.Dockerfile b/server/llm_engine_server/inference/pytorch_or_tf.Dockerfile deleted file mode 100644 index 999a0564..00000000 --- a/server/llm_engine_server/inference/pytorch_or_tf.Dockerfile +++ /dev/null @@ -1,81 +0,0 @@ -### THIS FILE IS DEPRECATED IN V1. INSTEAD, USE pytorch_or_tf.base.Dockerfile -### and pytorch_or_tf.user.Dockerfile -ARG BASE_IMAGE -FROM ${BASE_IMAGE} - -WORKDIR /app - -# Install basic packages. -# TODO: ffmpeg, libsm6, and lixext6 are essentially hardcoded from lidar. -# It's probably more correct to add support for arbitrary user-specified base images, -# otherwise this base image gets bloated over time. -RUN apt-get update && apt-get install -y \ - apt-utils \ - dumb-init \ - git \ - ssh \ - emacs-nox \ - htop \ - iftop \ - vim \ - ffmpeg \ - libsm6 \ - libxext6 \ - libcurl4-openssl-dev \ - libssl-dev \ - python3-dev \ - gcc \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -# Apparently wget has a vulnerability so we remove it here -RUN apt-get remove wget -y - -# Create a virtualenv for python so we install our packages in the right place -# Not sure how useful the existing contents of the pytorch image are anymore :/ Maybe it's used for cuda/cudnn installs -RUN python3 -m venv /venv -ENV PATH=/venv/bin:$PATH - -# Run everything as not-root user -RUN useradd -m llmengine -s /bin/bash -RUN chown -R llmengine /venv -RUN chown -R llmengine /app -# Limits for nproc and consequently number of files open -ADD llm_engine/llm_engine/inference/limits.conf /etc/security/limits.conf -USER llmengine - -RUN mkdir -p /app/ml_infra_core/llm_engine.core -RUN chown -R llmengine /app/ml_infra_core - -COPY --chown=llmengine ml_infra_core/llm_engine.core/requirements.txt ml_infra_core/llm_engine.core/requirements.txt -RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r ml_infra_core/llm_engine.core/requirements.txt --no-cache-dir -COPY --chown=llmengine ml_infra_core/llm_engine.core ml_infra_core/llm_engine.core -RUN pip install -e ml_infra_core/llm_engine.core - -# Not good for layer caching oh well -# The inference code should only need these few files/directories to function (hopefully) -# Don't copy the entire folder for security reasons - -RUN mkdir -p /app/llm_engine -RUN mkdir -p /app/llm_engine/llm_engine - -RUN chown -R llmengine /app/llm_engine - -COPY --chown=llmengine llm_engine/setup.py /app/llm_engine/setup.py -COPY --chown=llmengine llm_engine/llm_engine.egg-info /app/llm_engine/llm_engine.egg-info -COPY --chown=llmengine llm_engine/llm_engine/__init__.py /app/llm_engine/llm_engine/__init__.py -COPY --chown=llmengine llm_engine/llm_engine/common /app/llm_engine/llm_engine/common -COPY --chown=llmengine llm_engine/llm_engine/domain /app/llm_engine/llm_engine/domain -COPY --chown=llmengine llm_engine/llm_engine/infra /app/llm_engine/llm_engine/infra -COPY --chown=llmengine llm_engine/llm_engine/inference /app/llm_engine/llm_engine/inference -WORKDIR /app/llm_engine -RUN pip install -e . -WORKDIR /app - -RUN pip install -r /app/llm_engine/llm_engine/inference/requirements_base.txt -ARG REQUIREMENTS_FILE -COPY --chown=llmengine ${REQUIREMENTS_FILE} /app/llm_engine/llm_engine/inference/requirements.txt -RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r /app/llm_engine/llm_engine/inference/requirements.txt - - -ENV PYTHONPATH /app diff --git a/server/llm_engine_server/inference/pytorch_or_tf.base.Dockerfile b/server/llm_engine_server/inference/pytorch_or_tf.base.Dockerfile deleted file mode 100644 index 72b711ad..00000000 --- a/server/llm_engine_server/inference/pytorch_or_tf.base.Dockerfile +++ /dev/null @@ -1,78 +0,0 @@ -ARG BASE_IMAGE -FROM ${BASE_IMAGE} - -WORKDIR /app - -# Install basic packages. -# TODO: ffmpeg, libsm6, and lixext6 are essentially hardcoded from lidar. -# It's probably more correct to add support for arbitrary user-specified base images, -# otherwise this base image gets bloated over time. -RUN apt-get update && apt-get install -y \ - apt-utils \ - dumb-init \ - git \ - ssh \ - emacs-nox \ - htop \ - iftop \ - vim \ - ffmpeg \ - libsm6 \ - libxext6 \ - libcurl4-openssl-dev \ - libssl-dev \ - python3-dev \ - gcc \ - build-essential \ - && rm -rf /var/lib/apt/lists/* - -# Apparently wget has a vulnerability so we remove it here -RUN apt-get remove wget -y - -# Create a virtualenv for python so we install our packages in the right place -# Not sure how useful the existing contents of the pytorch image are anymore :/ Maybe it's used for cuda/cudnn installs -RUN python3 -m venv /venv -ENV PATH=/venv/bin:$PATH - -# Run everything as not-root user -RUN useradd -m llmengine -s /bin/bash -RUN chown -R llmengine /venv -RUN chown -R llmengine /app -# Limits for nproc and consequently number of files open -ADD llm_engine/llm_engine/inference/limits.conf /etc/security/limits.conf -USER llmengine - -RUN mkdir -p /app/ml_infra_core/llm_engine.core -RUN chown -R llmengine /app/ml_infra_core - -COPY --chown=llmengine ml_infra_core/llm_engine.core/requirements.txt ml_infra_core/llm_engine.core/requirements.txt -RUN --mount=type=secret,id=codeartifact-pip-conf,target=/etc/pip.conf,mode=0444 \ - PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf \ - pip install -r ml_infra_core/llm_engine.core/requirements.txt --no-cache-dir -COPY --chown=llmengine ml_infra_core/llm_engine.core ml_infra_core/llm_engine.core -RUN pip install -e ml_infra_core/llm_engine.core - -# Not good for layer caching oh well -# The inference code should only need these few files/directories to function (hopefully) -# Don't copy the entire folder for security reasons - -RUN mkdir -p /app/llm_engine -RUN mkdir -p /app/llm_engine/llm_engine - -RUN chown -R llmengine /app/llm_engine - -COPY --chown=llmengine \ - llm_engine/llm_engine/inference/requirements_base.txt \ - /app/llm_engine/llm_engine/inference/requirements_base.txt -RUN pip install -r /app/llm_engine/llm_engine/inference/requirements_base.txt - -COPY --chown=llmengine llm_engine/setup.py /app/llm_engine/setup.py -COPY --chown=llmengine llm_engine/llm_engine.egg-info /app/llm_engine/llm_engine.egg-info -COPY --chown=llmengine llm_engine/llm_engine/__init__.py /app/llm_engine/llm_engine/__init__.py -COPY --chown=llmengine llm_engine/llm_engine/common /app/llm_engine/llm_engine/common -COPY --chown=llmengine llm_engine/llm_engine/domain /app/llm_engine/llm_engine/domain -COPY --chown=llmengine llm_engine/llm_engine/infra /app/llm_engine/llm_engine/infra -COPY --chown=llmengine llm_engine/llm_engine/inference /app/llm_engine/llm_engine/inference -WORKDIR /app/llm_engine -RUN pip install -e . -WORKDIR /app diff --git a/server/llm_engine_server/inference/pytorch_or_tf.user.Dockerfile b/server/llm_engine_server/inference/pytorch_or_tf.user.Dockerfile deleted file mode 100644 index eb3c35df..00000000 --- a/server/llm_engine_server/inference/pytorch_or_tf.user.Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -ARG BASE_IMAGE -FROM ${BASE_IMAGE} - -ARG REQUIREMENTS_FILE -COPY --chown=llmengine ${REQUIREMENTS_FILE} /app/llm_engine/llm_engine/inference/requirements.txt -RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r /app/llm_engine/llm_engine/inference/requirements.txt - -ENV PYTHONPATH /app diff --git a/server/llm_engine_server/inference/requirements_base.txt b/server/llm_engine_server/inference/requirements_base.txt deleted file mode 100644 index 1d543b2d..00000000 --- a/server/llm_engine_server/inference/requirements_base.txt +++ /dev/null @@ -1,9 +0,0 @@ -fastapi==0.78.0 -uvicorn==0.17.6 -waitress==2.1.2 -smart_open==5.1.0 -# Pin typing-extensions so aioitertools doesn't break -typing-extensions>=4.1.1 -scale-launch>=0.1.0 -# Incompatibility between celery 5 and python 3.7 because of importlib-metadata 5, so we pin it -importlib-metadata<5.0;python_version<"3.8" diff --git a/server/llm_engine_server/inference/sync_inference/fastapi_server.py b/server/llm_engine_server/inference/sync_inference/fastapi_server.py deleted file mode 100644 index 78ec06d5..00000000 --- a/server/llm_engine_server/inference/sync_inference/fastapi_server.py +++ /dev/null @@ -1,106 +0,0 @@ -import traceback -from functools import wraps -from multiprocessing import BoundedSemaphore -from multiprocessing.synchronize import BoundedSemaphore as BoundedSemaphoreType -from typing import Optional - -from fastapi import BackgroundTasks, FastAPI, HTTPException, Response, status -from llm_engine_server.common.dtos.tasks import EndpointPredictV1Request -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.inference.common import ( - get_endpoint_config, - load_predict_fn_or_cls, - run_predict, -) -from llm_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( - DatadogInferenceMonitoringMetricsGateway, -) -from llm_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler -from llm_engine_server.inference.sync_inference.constants import ( - CONCURRENCY, - FAIL_ON_CONCURRENCY_LIMIT, - NAME, -) - -logger = make_logger(filename_wo_ext(__file__)) - - -class MultiprocessingConcurrencyLimiter: - # Shamelessly copied from std-ml-srv - def __init__(self, concurrency: Optional[int], fail_on_concurrency_limit: bool): - if concurrency is not None: - if concurrency < 1: - raise ValueError("Concurrency should be at least 1") - self.semaphore: Optional[BoundedSemaphoreType] = BoundedSemaphore(value=concurrency) - self.blocking = ( - not fail_on_concurrency_limit - ) # we want to block if we want to queue up requests - else: - self.semaphore = None - self.blocking = False # Unused - - def __enter__(self): - logger.debug("Entering concurrency limiter semaphore") - if self.semaphore and not self.semaphore.acquire(block=self.blocking): - logger.warning("Too many requests, returning 429") - raise HTTPException(status_code=429, detail="Too many requests") - # Just raises an HTTPException. - # __exit__ should not run; otherwise the release() doesn't have an acquire() - - def __exit__(self, type, value, traceback): - logger.debug("Exiting concurrency limiter semaphore") - if self.semaphore: - self.semaphore.release() - - -def with_concurrency_limit(concurrency_limiter: MultiprocessingConcurrencyLimiter): - # Shamelessly copied from std-ml-srv - def _inner(flask_func): - @wraps(flask_func) - def _inner_2(*args, **kwargs): - with concurrency_limiter: - return flask_func(*args, **kwargs) - - return _inner_2 - - return _inner - - -app = FastAPI(title=NAME) -concurrency_limiter = MultiprocessingConcurrencyLimiter(CONCURRENCY, FAIL_ON_CONCURRENCY_LIMIT) - -# How does this interact with threads? -# Analogous to init_worker() inside async_inference -predict_fn = load_predict_fn_or_cls() -endpoint_config = get_endpoint_config() -hooks = PostInferenceHooksHandler( - endpoint_name=endpoint_config.endpoint_name, - bundle_name=endpoint_config.bundle_name, - post_inference_hooks=endpoint_config.post_inference_hooks, - user_id=endpoint_config.user_id, - default_callback_url=endpoint_config.default_callback_url, - default_callback_auth=endpoint_config.default_callback_auth, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), -) - - -@app.get("/healthcheck") -@app.get("/healthz") -@app.get("/readyz") -def healthcheck(): - return Response(status_code=status.HTTP_200_OK) - - -@app.post("/predict") -@with_concurrency_limit(concurrency_limiter) -def predict(payload: EndpointPredictV1Request, background_tasks: BackgroundTasks): - """ - Assumption: payload is a JSON with format {"url": , "args": , "returned_pickled": boolean} - Returns: Results of running the predict function on the request url. See `run_predict`. - """ - try: - result = run_predict(predict_fn, payload) - background_tasks.add_task(hooks.handle, payload, result) - return result - except Exception: - raise HTTPException(status_code=500, detail=dict(traceback=str(traceback.format_exc()))) diff --git a/server/llm_engine_server/inference/sync_inference/start_fastapi_server.py b/server/llm_engine_server/inference/sync_inference/start_fastapi_server.py deleted file mode 100644 index c54b9e7b..00000000 --- a/server/llm_engine_server/inference/sync_inference/start_fastapi_server.py +++ /dev/null @@ -1,32 +0,0 @@ -import os -import subprocess - -from llm_engine_server.inference.common import unset_sensitive_envvars -from llm_engine_server.inference.sync_inference.constants import NUM_PROCESSES - -PORT = os.environ["PORT"] - - -def start_server(): - # TODO: HTTPS - # Copied from std-ml-srv - command = [ - "gunicorn", - "--bind", - f"[::]:{PORT}", - "--timeout", - "1200", - "--keep-alive", - "2", - "--worker-class", - "uvicorn.workers.UvicornWorker", - "--workers", - str(NUM_PROCESSES), - "llm_engine_server.inference.sync_inference.fastapi_server:app", - ] - unset_sensitive_envvars() - subprocess.run(command) - - -if __name__ == "__main__": - start_server() diff --git a/server/llm_engine_server/inference/user.Dockerfile b/server/llm_engine_server/inference/user.Dockerfile deleted file mode 100644 index 595097ed..00000000 --- a/server/llm_engine_server/inference/user.Dockerfile +++ /dev/null @@ -1,8 +0,0 @@ -ARG BASE_IMAGE -FROM ${BASE_IMAGE} - -ARG REQUIREMENTS_FILE -COPY --chown=root ${REQUIREMENTS_FILE} /app/llm_engine/llm_engine/inference/requirements.txt -RUN PIP_CONFIG_FILE=/kaniko/pip/codeartifact_pip_conf pip install -r /app/llm_engine/llm_engine/inference/requirements.txt - -ENV PYTHONPATH /app diff --git a/server/llm_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py b/server/llm_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py deleted file mode 100644 index 9c3860cc..00000000 --- a/server/llm_engine_server/infra/gateways/datadog_monitoring_metrics_gateway.py +++ /dev/null @@ -1,23 +0,0 @@ -from datadog import statsd -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway - - -class DatadogMonitoringMetricsGateway(MonitoringMetricsGateway): - def __init__(self): - self.tags = [f"env:{ml_infra_config().env}"] - - def emit_attempted_build_metric(self): - statsd.increment("scale_llm_engine_server.service_builder.attempt", tags=self.tags) - - def emit_successful_build_metric(self): - statsd.increment("scale_llm_engine_server.service_builder.success", tags=self.tags) - - def emit_docker_failed_build_metric(self): - statsd.increment("scale_llm_engine_server.service_builder.docker_failed", tags=self.tags) - - def emit_database_cache_hit_metric(self): - statsd.increment("scale_llm_engine_server.database_cache.hit", tags=self.tags) - - def emit_database_cache_miss_metric(self): - statsd.increment("scale_llm_engine_server.database_cache.miss", tags=self.tags) diff --git a/server/llm_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py b/server/llm_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py deleted file mode 100644 index a41ee417..00000000 --- a/server/llm_engine_server/infra/gateways/fake_monitoring_metrics_gateway.py +++ /dev/null @@ -1,44 +0,0 @@ -from collections import defaultdict - -from llm_engine_server.domain.gateways import MonitoringMetricsGateway - - -class FakeMonitoringMetricsGateway(MonitoringMetricsGateway): - def __init__(self): - self.attempted_build = 0 - self.successful_build = 0 - self.docker_failed_build = 0 - self.attempted_hook = defaultdict(int) - self.successful_hook = defaultdict(int) - self.database_cache_hit = 0 - self.database_cache_miss = 0 - - def reset(self): - self.attempted_build = 0 - self.successful_build = 0 - self.docker_failed_build = 0 - self.attempted_hook = defaultdict(int) - self.successful_hook = defaultdict(int) - self.database_cache_hit = 0 - self.database_cache_miss = 0 - - def emit_attempted_build_metric(self): - self.attempted_build += 1 - - def emit_successful_build_metric(self): - self.successful_build += 1 - - def emit_docker_failed_build_metric(self): - self.docker_failed_build += 1 - - def emit_attempted_post_inference_hook(self, hook: str): - self.attempted_hook[hook] += 1 - - def emit_successful_post_inference_hook(self, hook: str): - self.successful_hook[hook] += 1 - - def emit_database_cache_hit_metric(self): - self.database_cache_hit += 1 - - def emit_database_cache_miss_metric(self): - self.database_cache_miss += 1 diff --git a/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py b/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py deleted file mode 100644 index 0c812e8f..00000000 --- a/server/llm_engine_server/infra/gateways/live_streaming_model_endpoint_inference_gateway.py +++ /dev/null @@ -1,181 +0,0 @@ -from typing import Any, AsyncIterable, Dict - -import aiohttp -import orjson -import requests -import sseclient -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, - SyncEndpointPredictV1Response, - TaskStatus, -) -from llm_engine_server.common.env_vars import CIRCLECI, LOCAL -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import TooManyRequestsException, UpstreamServiceError -from llm_engine_server.domain.gateways.streaming_model_endpoint_inference_gateway import ( - StreamingModelEndpointInferenceGateway, -) -from llm_engine_server.infra.gateways.aiohttp_sse_client import EventSource -from llm_engine_server.infra.gateways.k8s_resource_parser import get_node_port -from orjson import JSONDecodeError -from tenacity import ( - AsyncRetrying, - RetryError, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -logger = make_logger(filename_wo_ext(__file__)) - -SYNC_ENDPOINT_RETRIES = 5 # Must be an integer >= 0 -SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 - - -def _get_streaming_endpoint_url(deployment_name: str) -> str: - if CIRCLECI: - # Circle CI: a NodePort is used to expose the service - # The IP address is obtained from `minikube ip`. - protocol: str = "http" - hostname: str = f"192.168.49.2:{get_node_port(deployment_name)}" - elif LOCAL: - # local development: the svc.cluster.local address is only available w/in the k8s cluster - protocol = "https" - hostname = f"{deployment_name}.{ml_infra_config().dns_host_domain}" - else: - protocol = "http" - # no need to hit external DNS resolution if we're w/in the k8s cluster - hostname = f"{deployment_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" - return f"{protocol}://{hostname}/stream" - - -def _serialize_json(data) -> str: - # Use orjson, which is faster and more correct than native Python json library. - # This is more important for sync endpoints, which are more latency-sensitive. - return orjson.dumps(data).decode() - - -class LiveStreamingModelEndpointInferenceGateway(StreamingModelEndpointInferenceGateway): - """ - Concrete implementation for an StreamingModelEndpointInferenceGateway. - - make_single_request() makes the streaming inference request to the endpoint - make_request_with_retries() wraps make_single_request() with retries - streaming_predict() wraps make_request_with_retries() and yields SyncEndpointPredictV1Response - """ - - def __init__(self, use_asyncio: bool): - self.use_asyncio = use_asyncio - - async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): - errored = False - if self.use_asyncio: - - async with aiohttp.ClientSession(json_serialize=_serialize_json) as aioclient: - aio_resp = await aioclient.post( - request_url, - json=payload_json, - headers={"Content-Type": "application/json"}, - ) - status = aio_resp.status - if status == 200: - async with EventSource(response=aio_resp) as event_source: - async for event in event_source: - yield event.data - else: - content = await aio_resp.read() - errored = True - else: - resp = requests.post( - request_url, - json=payload_json, - headers={"Content-Type": "application/json"}, - stream=True, - ) - client = sseclient.SSEClient(resp) - status = resp.status_code - if status == 200: - for event in client.events(): - yield event.data - else: - content = resp.content - errored = True - - # Need to have these exceptions raised outside the async context so that - # tenacity can properly capture them. - if errored: - if status == 429: - raise TooManyRequestsException("429 returned") - else: - raise UpstreamServiceError(status_code=status, content=content) - - async def make_request_with_retries( - self, - request_url: str, - payload_json: Dict[str, Any], - timeout_seconds: float, - num_retries: int, - ) -> AsyncIterable[Dict[str, Any]]: - # Copied from document-endpoint - # More details at https://tenacity.readthedocs.io/en/latest/#retrying-code-block - # Try/catch + for loop makes us retry only when we get a 429 from the synchronous endpoint. - # We should be creating a new requests Session each time, which should avoid sending - # requests to the same endpoint. This is admittedly a hack until we get proper - # least-outstanding-requests load balancing to our http endpoints - - try: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(num_retries + 1), - retry=retry_if_exception_type(TooManyRequestsException), - wait=wait_exponential(multiplier=1, min=1, max=timeout_seconds), - ): - with attempt: - logger.info(f"Retry number {attempt.retry_state.attempt_number}") - response = self.make_single_request(request_url, payload_json) - async for item in response: - yield orjson.loads(item) - return - except RetryError: - logger.warning("Hit max # of retries, returning 429 to client") - raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") - except JSONDecodeError: - logger.exception("JSONDecodeError") - raise UpstreamServiceError(status_code=500, content=b"JSONDecodeError") - - # Never reached because tenacity should throw a RetryError if we exit the for loop. - # This is for mypy. - # pragma: no cover - raise Exception("Should never reach this line") - - async def streaming_predict( - self, topic: str, predict_request: EndpointPredictV1Request - ) -> AsyncIterable[SyncEndpointPredictV1Response]: - deployment_url = _get_streaming_endpoint_url(topic) - - try: - response = self.make_request_with_retries( - request_url=deployment_url, - payload_json=predict_request.dict(), - timeout_seconds=SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS, - num_retries=SYNC_ENDPOINT_RETRIES, - ) - async for item in response: - yield SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=item) - except UpstreamServiceError as exc: - logger.error(f"Service error on sync task: {exc.content!r}") - try: - error_json = orjson.loads(exc.content.decode("utf-8")) - result_traceback = ( - error_json.get("detail", {}).get("traceback") - if isinstance(error_json, dict) - else None - ) - except JSONDecodeError: - result_traceback = exc.content.decode() - - yield SyncEndpointPredictV1Response( - status=TaskStatus.FAILURE, - traceback=result_traceback, - ) diff --git a/server/llm_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py b/server/llm_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py deleted file mode 100644 index 5022aeed..00000000 --- a/server/llm_engine_server/infra/gateways/live_sync_model_endpoint_inference_gateway.py +++ /dev/null @@ -1,158 +0,0 @@ -from typing import Any, Dict - -import aiohttp -import orjson -import requests -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, - SyncEndpointPredictV1Response, - TaskStatus, -) -from llm_engine_server.common.env_vars import CIRCLECI, LOCAL -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.exceptions import TooManyRequestsException, UpstreamServiceError -from llm_engine_server.domain.gateways.sync_model_endpoint_inference_gateway import ( - SyncModelEndpointInferenceGateway, -) -from llm_engine_server.infra.gateways.k8s_resource_parser import get_node_port -from orjson import JSONDecodeError -from tenacity import ( - AsyncRetrying, - RetryError, - retry_if_exception_type, - stop_after_attempt, - wait_exponential, -) - -logger = make_logger(filename_wo_ext(__file__)) - -SYNC_ENDPOINT_RETRIES = 5 # Must be an integer >= 0 -SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS = 10 - - -def _get_sync_endpoint_url(deployment_name: str) -> str: - if CIRCLECI: - # Circle CI: a NodePort is used to expose the service - # The IP address is obtained from `minikube ip`. - protocol: str = "http" - hostname: str = f"192.168.49.2:{get_node_port(deployment_name)}" - elif LOCAL: - # local development: the svc.cluster.local address is only available w/in the k8s cluster - protocol = "https" - hostname = f"{deployment_name}.{ml_infra_config().dns_host_domain}" - else: - protocol = "http" - # no need to hit external DNS resolution if we're w/in the k8s cluster - hostname = f"{deployment_name}.{hmi_config.endpoint_namespace}.svc.cluster.local" - return f"{protocol}://{hostname}/predict" - - -def _serialize_json(data) -> str: - # Use orjson, which is faster and more correct than native Python json library. - # This is more important for sync endpoints, which are more latency-sensitive. - return orjson.dumps(data).decode() - - -class LiveSyncModelEndpointInferenceGateway(SyncModelEndpointInferenceGateway): - """ - Concrete implementation for an SyncModelEndpointInferenceGateway. - """ - - def __init__(self, use_asyncio: bool): - self.use_asyncio = use_asyncio - - async def make_single_request(self, request_url: str, payload_json: Dict[str, Any]): - if self.use_asyncio: - async with aiohttp.ClientSession(json_serialize=_serialize_json) as client: - aio_resp = await client.post( - request_url, - json=payload_json, - headers={"Content-Type": "application/json"}, - ) - status = aio_resp.status - if status == 200: - return await aio_resp.json() - content = await aio_resp.read() - else: - resp = requests.post( - request_url, - json=payload_json, - headers={"Content-Type": "application/json"}, - ) - status = resp.status_code - if status == 200: - return resp.json() - content = resp.content - - # Need to have these exceptions raised outside the async context so that - # tenacity can properly capture them. - if status == 429: - raise TooManyRequestsException("429 returned") - else: - raise UpstreamServiceError(status_code=status, content=content) - - async def make_request_with_retries( - self, - request_url: str, - payload_json: Dict[str, Any], - timeout_seconds: float, - num_retries: int, - ) -> Dict[str, Any]: - # Copied from document-endpoint - # More details at https://tenacity.readthedocs.io/en/latest/#retrying-code-block - # Try/catch + for loop makes us retry only when we get a 429 from the synchronous endpoint. - # We should be creating a new requests Session each time, which should avoid sending - # requests to the same endpoint. This is admittedly a hack until we get proper - # least-outstanding-requests load balancing to our http endpoints - - try: - async for attempt in AsyncRetrying( - stop=stop_after_attempt(num_retries + 1), - retry=retry_if_exception_type(TooManyRequestsException), - wait=wait_exponential(multiplier=1, min=1, max=timeout_seconds), - ): - with attempt: - logger.info(f"Retry number {attempt.retry_state.attempt_number}") - return await self.make_single_request(request_url, payload_json) - except RetryError: - logger.warning("Hit max # of retries, returning 429 to client") - raise UpstreamServiceError(status_code=429, content=b"Too many concurrent requests") - - # Never reached because tenacity should throw a RetryError if we exit the for loop. - # This is for mypy. - # pragma: no cover - return {} - - async def predict( - self, topic: str, predict_request: EndpointPredictV1Request - ) -> SyncEndpointPredictV1Response: - deployment_url = _get_sync_endpoint_url(topic) - - try: - response = await self.make_request_with_retries( - request_url=deployment_url, - payload_json=predict_request.dict(), - timeout_seconds=SYNC_ENDPOINT_MAX_TIMEOUT_SECONDS, - num_retries=SYNC_ENDPOINT_RETRIES, - ) - except UpstreamServiceError as exc: - logger.error(f"Service error on sync task: {exc.content!r}") - try: - error_json = orjson.loads(exc.content.decode("utf-8")) - result_traceback = ( - error_json.get("detail", {}).get("traceback") - if isinstance(error_json, dict) - else None - ) - return SyncEndpointPredictV1Response( - status=TaskStatus.FAILURE, - traceback=result_traceback, - ) - except JSONDecodeError: - return SyncEndpointPredictV1Response( - status=TaskStatus.FAILURE, traceback=exc.content.decode() - ) - - return SyncEndpointPredictV1Response(status=TaskStatus.SUCCESS, result=response) diff --git a/server/llm_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py b/server/llm_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py deleted file mode 100644 index 85aaa945..00000000 --- a/server/llm_engine_server/infra/gateways/resources/sqs_endpoint_resource_delegate.py +++ /dev/null @@ -1,48 +0,0 @@ -from abc import ABC, abstractmethod -from typing import Any, Dict, NamedTuple, Sequence - -from mypy_boto3_sqs.type_defs import GetQueueAttributesResultTypeDef - -__all__: Sequence[str] = ( - "SQSQueueInfo", - "SQSEndpointResourceDelegate", -) - - -class SQSQueueInfo(NamedTuple): - queue_name: str - queue_url: str - - -class SQSEndpointResourceDelegate(ABC): - """ - Base class for an interactor with SQS. This is used by the LiveEndpointResourceGateway. - """ - - @abstractmethod - async def create_queue_if_not_exists( - self, - endpoint_id: str, - endpoint_name: str, - endpoint_created_by: str, - endpoint_labels: Dict[str, Any], - ) -> SQSQueueInfo: - """ - Creates an SQS queue associated with the given endpoint_id. Other fields are set as tags on the queue. - """ - - @abstractmethod - async def delete_queue(self, endpoint_id: str) -> None: - """ - Deletes an SQS queue associated with the given endpoint_id. This is a no-op if the queue does not exist. - """ - - @abstractmethod - async def get_queue_attributes(self, endpoint_id: str) -> GetQueueAttributesResultTypeDef: - """ - Get all attributes of an SQS queue. - """ - - @staticmethod - def endpoint_id_to_queue_name(endpoint_id: str) -> str: - return f"llm-engine-endpoint-id-{endpoint_id}" diff --git a/server/llm_engine_server/infra/repositories/ecr_docker_repository.py b/server/llm_engine_server/infra/repositories/ecr_docker_repository.py deleted file mode 100644 index eb61287d..00000000 --- a/server/llm_engine_server/infra/repositories/ecr_docker_repository.py +++ /dev/null @@ -1,44 +0,0 @@ -from typing import Optional - -from llm_engine_server.common.dtos.docker_repository import BuildImageRequest, BuildImageResponse -from llm_engine_server.core.config import ml_infra_config -from llm_engine_server.core.docker.ecr import image_exists as ecr_image_exists -from llm_engine_server.core.docker.remote_build import build_remote_block -from llm_engine_server.domain.repositories import DockerRepository - - -class ECRDockerRepository(DockerRepository): - def image_exists( - self, image_tag: str, repository_name: str, aws_profile: Optional[str] = None - ) -> bool: - return ecr_image_exists( - image_tag=image_tag, - repository_name=repository_name, - aws_profile=aws_profile, - ) - - def get_image_url(self, image_tag: str, repository_name: str) -> str: - return f"{ml_infra_config().docker_repo_prefix}/{repository_name}:{image_tag}" - - def build_image(self, image_params: BuildImageRequest) -> BuildImageResponse: - folders_to_include = [ - "llm_engine", - ] - if image_params.requirements_folder: - folders_to_include.append(image_params.requirements_folder) - - build_args = { - "BASE_IMAGE": image_params.base_image, - } - - if image_params.substitution_args: - build_args.update(image_params.substitution_args) - - build_result = build_remote_block( - context=image_params.base_path, - dockerfile=image_params.dockerfile, - repotags=[f"{image_params.repo}:{image_params.image_tag}"], - folders_to_include=folders_to_include, - build_args=build_args, - ) - return BuildImageResponse(status=build_result.status, logs=build_result.logs) diff --git a/server/llm_engine_server/infra/services/image_cache_service.py b/server/llm_engine_server/infra/services/image_cache_service.py deleted file mode 100644 index a48f8c95..00000000 --- a/server/llm_engine_server/infra/services/image_cache_service.py +++ /dev/null @@ -1,107 +0,0 @@ -from datetime import datetime -from typing import Dict, NamedTuple, Tuple - -from llm_engine_server.core.loggers import filename_wo_ext, make_logger -from llm_engine_server.domain.entities import GpuType, ModelEndpointInfraState -from llm_engine_server.domain.repositories import DockerRepository -from llm_engine_server.infra.gateways.resources.image_cache_gateway import ( - CachedImages, - ImageCacheGateway, -) -from llm_engine_server.infra.repositories.model_endpoint_record_repository import ( - ModelEndpointRecordRepository, -) - -logger = make_logger(filename_wo_ext(__name__)) - -IMAGES_TO_CACHE_PER_INSTANCE_TYPE = 32 - -CachePriority = NamedTuple( - "CachePriority", - ( - ("is_high_priority", int), - ("has_no_available_workers", int), - ("last_updated_at", datetime), - ), -) - - -class ImageCacheService: - """ - Represents reading from k8s and writing images to the k8s image cache. - """ - - def __init__( - self, - model_endpoint_record_repository: ModelEndpointRecordRepository, - image_cache_gateway: ImageCacheGateway, - docker_repository: DockerRepository, - ): - self.model_endpoint_record_repository = model_endpoint_record_repository - self.image_cache_gateway = image_cache_gateway - self.docker_repository = docker_repository - - async def execute(self, endpoint_infra_states: Dict[str, Tuple[bool, ModelEndpointInfraState]]): - images_to_cache_priority: Dict[str, Dict[str, CachePriority]] = { - "cpu": {}, - "a10": {}, - "a100": {}, - "t4": {}, - } - for endpoint_id, (_, state) in endpoint_infra_states.items(): - record = await self.model_endpoint_record_repository.get_model_endpoint_record( - endpoint_id - ) - - if record is None: - continue - - last_updated_at = record.last_updated_at or datetime.min - has_no_available_workers = int(state.deployment_state.available_workers == 0) - is_high_priority = int(state.high_priority is True) - - # TODO: Adding for image cache stability and to make it faster. Remove this - # condition when things are proven to run smoothly. - if not state.high_priority: - continue - - cache_priority = CachePriority( - is_high_priority=is_high_priority, - has_no_available_workers=has_no_available_workers, - last_updated_at=last_updated_at, - ) - - image_repository_and_tag = state.image.split("/", 1)[1] - repository_name, image_tag = image_repository_and_tag.split(":") - if state.resource_state.gpus == 0 and ( - ( - state.image not in images_to_cache_priority["cpu"] - or last_updated_at - > images_to_cache_priority["cpu"][state.image].last_updated_at - ) - and self.docker_repository.image_exists(image_tag, repository_name) - ): - images_to_cache_priority["cpu"][state.image] = cache_priority - elif state.resource_state.gpus > 0: - for gpu_type, key in [ - (GpuType.NVIDIA_AMPERE_A10, "a10"), - (GpuType.NVIDIA_AMPERE_A100, "a100"), - (GpuType.NVIDIA_TESLA_T4, "t4"), - ]: - if state.resource_state.gpu_type == gpu_type and ( - ( - state.image not in images_to_cache_priority[key] - or last_updated_at - > images_to_cache_priority[key][state.image].last_updated_at - ) - and self.docker_repository.image_exists(image_tag, repository_name) - ): - images_to_cache_priority[key][state.image] = cache_priority - - images_to_cache = CachedImages(cpu=[], a10=[], a100=[], t4=[]) - for key, val in images_to_cache_priority.items(): - images_to_cache[key] = sorted( # type: ignore - val.keys(), key=lambda image: val[image], reverse=True - )[:IMAGES_TO_CACHE_PER_INSTANCE_TYPE] - - await self.image_cache_gateway.create_or_update_image_cache(images_to_cache) diff --git a/server/llm_engine_server/scripts/autogenerate_client_and_docs.py b/server/llm_engine_server/scripts/autogenerate_client_and_docs.py deleted file mode 100644 index 973e7008..00000000 --- a/server/llm_engine_server/scripts/autogenerate_client_and_docs.py +++ /dev/null @@ -1,39 +0,0 @@ -import json -import subprocess -from pathlib import Path - -from llm_engine_server.api.app import app - -MODULE_PATH = Path(__file__).resolve() -LLM_ENGINE_SERVICE_BASE = MODULE_PATH.parents[2].resolve() -OPENAPI_PATH = (LLM_ENGINE_SERVICE_BASE / "clients/openapi.json").resolve() -LANGUAGE_TO_GENERATOR_NAME = dict(python="python", typescript="typescript-axios") - - -def dump_openapi(openapi_path: str): - """Writes the OpenAPI schema to the specified path.""" - with open(openapi_path, "w") as file: - schema = app.openapi() - file.write(json.dumps(schema, indent=4, sort_keys=True)) - - -def run_openapi_generator(): - """Launches a subprocess with the OpenAPI generator.""" - print("🏭 Generating client") - command = ["docker-compose run openapi-generator-cli"] - subprocess.run( - command, - cwd=str((LLM_ENGINE_SERVICE_BASE / "../ml_infra_core").resolve()), - check=True, - shell=True, - ) - - -def entrypoint(): - """Entrypoint for autogenerating client and documentation.""" - dump_openapi(str(OPENAPI_PATH)) - run_openapi_generator() - - -if __name__ == "__main__": - entrypoint() diff --git a/server/llm_engine_server/scripts/copy_to_public_client.sh b/server/llm_engine_server/scripts/copy_to_public_client.sh deleted file mode 100755 index a40afddf..00000000 --- a/server/llm_engine_server/scripts/copy_to_public_client.sh +++ /dev/null @@ -1,11 +0,0 @@ -#!/bin/bash -# Usage: bash build_and_publish_to_codeartifact.sh $PATH_TO_PRIVATE_CLIENT $PATH_TO_PUBLIC_CLIENT - -set -e -PRIVATE_CLIENT_ROOT=$1 -PUBLIC_CLIENT_ROOT=$2 - -rm -rf $PUBLIC_CLIENT_ROOT/launch/api_client/* -cp -r $PRIVATE_CLIENT_ROOT/launch_client/* $PUBLIC_CLIENT_ROOT/launch/api_client/ - -sed -i '' 's/launch_client/launch.api_client/g' $(find $PUBLIC_CLIENT_ROOT/launch/api_client -type f -name '*\.py') diff --git a/server/llm_engine_server/service_builder/celery.py b/server/llm_engine_server/service_builder/celery.py deleted file mode 100644 index 57c9f623..00000000 --- a/server/llm_engine_server/service_builder/celery.py +++ /dev/null @@ -1,13 +0,0 @@ -from llm_engine_server.core.celery import celery_app -from llm_engine_server.core.config import ml_infra_config - -service_builder_service = celery_app( - name="llm_engine_server.service_builder", - modules=[ - "llm_engine_server.service_builder.tasks_v1", - ], - s3_bucket=ml_infra_config().s3_bucket, -) - -if __name__ == "__main__": - service_builder_service.start() diff --git a/server/llm_engine_server/service_builder/tasks_v1.py b/server/llm_engine_server/service_builder/tasks_v1.py deleted file mode 100644 index 8548f149..00000000 --- a/server/llm_engine_server/service_builder/tasks_v1.py +++ /dev/null @@ -1,99 +0,0 @@ -import asyncio -import os -from typing import Any, Dict - -import aioredis -from celery.signals import worker_process_init -from llm_engine_server.common.config import hmi_config -from llm_engine_server.common.constants import READYZ_FPATH -from llm_engine_server.common.dtos.endpoint_builder import ( - BuildEndpointRequest, - BuildEndpointResponse, -) -from llm_engine_server.common.env_vars import CIRCLECI -from llm_engine_server.core.fake_notification_gateway import FakeNotificationGateway -from llm_engine_server.core.notification_gateway import NotificationGateway -from llm_engine_server.db.base import SessionAsyncNullPool -from llm_engine_server.domain.gateways.monitoring_metrics_gateway import MonitoringMetricsGateway -from llm_engine_server.infra.gateways import FakeMonitoringMetricsGateway, S3FilesystemGateway -from llm_engine_server.infra.gateways.resources.fake_sqs_endpoint_resource_delegate import ( - FakeSQSEndpointResourceDelegate, -) -from llm_engine_server.infra.gateways.resources.k8s_endpoint_resource_delegate import ( - set_lazy_load_kubernetes_clients, -) -from llm_engine_server.infra.gateways.resources.live_endpoint_resource_gateway import ( - LiveEndpointResourceGateway, -) -from llm_engine_server.infra.gateways.resources.live_sqs_endpoint_resource_delegate import ( - LiveSQSEndpointResourceDelegate, -) -from llm_engine_server.infra.gateways.resources.sqs_endpoint_resource_delegate import ( - SQSEndpointResourceDelegate, -) -from llm_engine_server.infra.repositories import ( - DbModelEndpointRecordRepository, - ECRDockerRepository, - RedisFeatureFlagRepository, - RedisModelEndpointCacheRepository, -) -from llm_engine_server.infra.services import LiveEndpointBuilderService -from llm_engine_server.service_builder.celery import service_builder_service - -# Need to disable lazy loading of k8s clients because each event loop should contain its own k8s -# client, which constructs the aiohttp.ClientSession in the event loop. -set_lazy_load_kubernetes_clients(False) - - -async def _build_endpoint( - build_endpoint_request: BuildEndpointRequest, -) -> BuildEndpointResponse: - session = SessionAsyncNullPool - pool = aioredis.BlockingConnectionPool.from_url(hmi_config.cache_redis_url) - redis = aioredis.Redis(connection_pool=pool) - sqs_delegate: SQSEndpointResourceDelegate - notification_gateway: NotificationGateway - if CIRCLECI: - sqs_delegate = FakeSQSEndpointResourceDelegate() - else: - sqs_delegate = LiveSQSEndpointResourceDelegate( - sqs_profile=os.getenv("SQS_PROFILE", hmi_config.sqs_profile) - ) - monitoring_metrics_gateway: MonitoringMetricsGateway - monitoring_metrics_gateway = FakeMonitoringMetricsGateway() - notification_gateway = FakeNotificationGateway() - - service = LiveEndpointBuilderService( - docker_repository=ECRDockerRepository(), - resource_gateway=LiveEndpointResourceGateway(sqs_delegate=sqs_delegate), - monitoring_metrics_gateway=monitoring_metrics_gateway, - model_endpoint_record_repository=DbModelEndpointRecordRepository( - monitoring_metrics_gateway=monitoring_metrics_gateway, - session=session, - read_only=False, - ), - model_endpoint_cache_repository=RedisModelEndpointCacheRepository(redis_client=redis), - filesystem_gateway=S3FilesystemGateway(), - notification_gateway=notification_gateway, - feature_flag_repo=RedisFeatureFlagRepository(redis_client=redis), - ) - response = await service.build_endpoint(build_endpoint_request) - await redis.close() - await pool.disconnect() - return response - - -@worker_process_init.connect -def init_worker(*args, **kwargs): - # k8s health check - with open(READYZ_FPATH, "w") as f: - f.write("READY") - - -@service_builder_service.task -def build_endpoint(build_endpoint_request_json: Dict[str, Any]) -> Dict[str, str]: - build_endpoint_request: BuildEndpointRequest = BuildEndpointRequest.parse_obj( - build_endpoint_request_json - ) - result = asyncio.run(_build_endpoint(build_endpoint_request)) - return result.dict() diff --git a/server/mypy.ini b/server/mypy.ini deleted file mode 100644 index 316c36ef..00000000 --- a/server/mypy.ini +++ /dev/null @@ -1,21 +0,0 @@ -[mypy] -ignore_missing_imports = True -follow_imports = silent -show_column_numbers = True -namespace_packages = True -explicit_package_bases = True -strict_optional = True -plugins = pydantic.mypy -exclude = clients - -[mypy-llm_engine_server.core.*] -ignore_errors = True - -[mypy-llm_engine_server.db.*] -ignore_errors = True - -[mypy-llm_engine_server.infra.repositories.*] -ignore_errors = True - -[mypy-tests.*] -ignore_errors = True diff --git a/server/pyproject.toml b/server/pyproject.toml deleted file mode 100644 index 0a7ba88b..00000000 --- a/server/pyproject.toml +++ /dev/null @@ -1,6 +0,0 @@ -[build-system] -requires = [ - "setuptools", - "wheel", -] -build-backend = 'setuptools.build_meta' diff --git a/server/requirements.in b/server/requirements.in deleted file mode 100644 index 8d405e0f..00000000 --- a/server/requirements.in +++ /dev/null @@ -1,49 +0,0 @@ -GitPython~=3.0 -Jinja2==3.0.3 # version 3.1.0 had a bug -aiohttp~=3.8 -aioredis~=2.0 -alembic==1.8.1 -asyncpg==0.27.0 -boto3-stubs[essential]==1.26.67 -boto3~=1.21 -botocore~=1.24 -build==0.8.0 -celery[redis,sqs,tblib]~=5.2 -click~=8.1 -cloudpickle==2.1.0 -dataclasses-json>=0.5.7 -datadog-api-client==2.11.0 -datadog~=0.46.0 -ddtrace~=0.49.2 -deprecation~=2.1 -docker~=5.0 -fastapi==0.78.0 -gitdb2~=2.0 -gunicorn~=20.0 -httptools==0.5.0 -json-log-formatter~=0.3 -kubeconfig~=1.1 -kubernetes-asyncio==24.2.2 -kubernetes~=25.3.0 -orjson==3.8.6 -protobuf~=3.20 -psycopg2-binary==2.9.3 -py-xid==0.3.0 -pycurl~=7.44 # For celery[sqs] -pydantic~=1.10 -quart==0.18.3 -requests-auth-aws-sigv4~=0.7 -requests~=2.25 -rich~=12.6 -sh~=1.13 -smart-open~=5.2 -sqlalchemy[asyncio]==2.0.4 -sse-starlette==1.6.1 -sseclient-py==1.7.2 -tenacity>=6.0.0,<=6.2.0 -testing-postgresql==1.3.0 -tqdm~=4.64 -twine==3.7.1 -uvicorn==0.17.6 -uvloop==0.17.0 -yarl~=1.4 diff --git a/server/requirements.txt b/server/requirements.txt deleted file mode 100644 index 69424330..00000000 --- a/server/requirements.txt +++ /dev/null @@ -1,443 +0,0 @@ -# -# This file is autogenerated by pip-compile with python 3.8 -# To update, run: -# -# pip-compile --allow-unsafe --no-emit-index-url --no-emit-trusted-host --output-file=requirements.txt requirements.in -# -aiofiles==23.1.0 - # via quart -aiohttp==3.8.4 - # via - # -r requirements.in - # kubernetes-asyncio -aioredis==2.0.1 - # via -r requirements.in -aiosignal==1.3.1 - # via aiohttp -alembic==1.8.1 - # via -r requirements.in -amqp==5.1.1 - # via kombu -anyio==3.7.1 - # via starlette -asgiref==3.7.2 - # via uvicorn -asn1crypto==1.5.1 - # via scramp -async-timeout==4.0.2 - # via - # aiohttp - # aioredis - # redis -asyncpg==0.27.0 - # via -r requirements.in -attrs==23.1.0 - # via - # aiohttp - # ddtrace-scale -backports-zoneinfo[tzdata]==0.2.1 - # via - # celery - # kombu -billiard==4.1.0 - # via celery -bleach==6.0.0 - # via readme-renderer -blinker==1.6.2 - # via quart -boto3==1.28.1 - # via - # -r requirements.in - # celery -boto3-stubs[essential]==1.26.67 - # via -r requirements.in -botocore==1.31.1 - # via - # -r requirements.in - # boto3 - # s3transfer -botocore-stubs==1.29.165 - # via boto3-stubs -build==0.8.0 - # via -r requirements.in -cachetools==5.3.1 - # via google-auth -celery[redis,sqs,tblib]==5.3.1 - # via -r requirements.in -certifi==2023.5.7 - # via - # datadog-api-client - # kubernetes - # kubernetes-asyncio - # requests -charset-normalizer==3.2.0 - # via - # aiohttp - # requests -click==8.1.4 - # via - # -r requirements.in - # celery - # click-didyoumean - # click-plugins - # click-repl - # quart - # uvicorn -click-didyoumean==0.3.0 - # via celery -click-plugins==1.1.1 - # via celery -click-repl==0.3.0 - # via celery -cloudpickle==2.1.0 - # via -r requirements.in -colorama==0.4.6 - # via twine -commonmark==0.9.1 - # via rich -dataclasses-json==0.5.9 - # via -r requirements.in -datadog-api-client==2.11.0 - # via -r requirements.in -datadog==0.46.0 - # via -r requirements.in -ddtrace==0.49.2 - # via -r requirements.in -deprecation==2.1.0 - # via -r requirements.in -docker==5.0.3 - # via -r requirements.in -docutils==0.20.1 - # via readme-renderer -exceptiongroup==1.1.2 - # via anyio -fastapi==0.78.0 - # via -r requirements.in -frozenlist==1.3.3 - # via - # aiohttp - # aiosignal -gitdb==4.0.10 - # via gitpython -gitdb2==2.0.6 - # via -r requirements.in -gitpython==3.1.32 - # via -r requirements.in -google-auth==2.21.0 - # via kubernetes -greenlet==2.0.2 - # via sqlalchemy -gunicorn==20.1.0 - # via -r requirements.in -h11==0.14.0 - # via - # hypercorn - # uvicorn - # wsproto -h2==4.1.0 - # via hypercorn -hpack==4.0.0 - # via h2 -httptools==0.5.0 - # via -r requirements.in -hypercorn==0.14.4 - # via quart -hyperframe==6.0.1 - # via h2 -idna==3.4 - # via - # anyio - # requests - # yarl -importlib-metadata==6.8.0 - # via - # alembic - # keyring - # quart - # twine -importlib-resources==6.0.0 - # via - # alembic - # keyring -itsdangerous==2.1.2 - # via quart -jaraco-classes==3.3.0 - # via keyring -jinja2==3.0.3 - # via - # -r requirements.in - # quart -jmespath==1.0.1 - # via - # boto3 - # botocore -json-log-formatter==0.5.2 - # via -r requirements.in -keyring==24.2.0 - # via twine -kombu==5.3.1 - # via celery -kubeconfig==1.1.1 - # via -r requirements.in -kubernetes==25.3.0 - # via -r requirements.in -kubernetes-asyncio==24.2.2 - # via -r requirements.in -mako==1.2.4 - # via alembic -markupsafe==2.1.3 - # via - # jinja2 - # mako - # quart - # werkzeug -marshmallow==3.19.0 - # via - # dataclasses-json - # marshmallow-enum -marshmallow-enum==1.5.1 - # via dataclasses-json -more-itertools==9.1.0 - # via jaraco-classes -multidict==6.0.4 - # via - # aiohttp - # yarl -mypy-boto3-cloudformation==1.26.156 - # via boto3-stubs -mypy-boto3-dynamodb==1.26.164 - # via boto3-stubs -mypy-boto3-ec2==1.26.157 - # via boto3-stubs -mypy-boto3-lambda==1.26.163 - # via boto3-stubs -mypy-boto3-rds==1.26.163 - # via boto3-stubs -mypy-boto3-s3==1.26.163 - # via boto3-stubs -mypy-boto3-sqs==1.26.148 - # via boto3-stubs -mypy-extensions==1.0.0 - # via typing-inspect -oauthlib==3.2.2 - # via requests-oauthlib -orjson==3.8.6 - # via -r requirements.in -packaging==23.1 - # via - # build - # ddtrace-scale - # deprecation - # marshmallow -pep517==0.13.0 - # via build -pg8000==1.29.8 - # via testing-postgresql -pkginfo==1.9.6 - # via twine -priority==2.0.0 - # via hypercorn -prompt-toolkit==3.0.39 - # via click-repl -protobuf==3.20.3 - # via - # -r requirements.in - # ddtrace-scale -psycopg2-binary==2.9.3 - # via -r requirements.in -py-xid==0.3.0 - # via -r requirements.in -pyasn1==0.5.0 - # via - # pyasn1-modules - # rsa -pyasn1-modules==0.3.0 - # via google-auth -pycurl==7.45.2 - # via - # -r requirements.in - # celery -pydantic==1.10.11 - # via - # -r requirements.in - # fastapi -pygments==2.15.1 - # via - # readme-renderer - # rich -python-dateutil==2.8.2 - # via - # botocore - # celery - # datadog-api-client - # kubernetes - # kubernetes-asyncio - # pg8000 -pyyaml==6.0 - # via - # kubeconfig - # kubernetes - # kubernetes-asyncio -quart==0.18.3 - # via -r requirements.in -readme-renderer==40.0 - # via twine -redis==4.6.0 - # via celery -requests==2.31.0 - # via - # -r requirements.in - # datadog-scale - # docker - # kubernetes - # requests-auth-aws-sigv4 - # requests-oauthlib - # requests-toolbelt - # twine -requests-auth-aws-sigv4==0.7 - # via -r requirements.in -requests-oauthlib==1.3.1 - # via kubernetes -requests-toolbelt==1.0.0 - # via twine -rfc3986==2.0.0 - # via twine -rich==12.6.0 - # via -r requirements.in -rsa==4.9 - # via google-auth -s3transfer==0.6.1 - # via boto3 -scramp==1.4.4 - # via pg8000 -sh==1.14.3 - # via -r requirements.in -six==1.16.0 - # via - # bleach - # ddtrace-scale - # google-auth - # kubernetes - # kubernetes-asyncio - # python-dateutil - # tenacity -smart-open==5.2.1 - # via -r requirements.in -smmap==5.0.0 - # via - # gitdb - # smmap2 -smmap2==3.0.1 - # via gitdb2 -sniffio==1.3.0 - # via anyio -sqlalchemy[asyncio]==2.0.4 - # via - # -r requirements.in - # alembic -sse-starlette==1.6.1 - # via -r requirements.in -sseclient-py==1.7.2 - # via -r requirements.in -starlette==0.19.1 - # via - # fastapi - # sse-starlette -tblib==2.0.0 - # via celery -tenacity==6.2.0 - # via - # -r requirements.in - # ddtrace-scale -testing-common-database==2.0.3 - # via testing-postgresql -testing-postgresql==1.3.0 - # via -r requirements.in -tomli==2.0.1 - # via - # build - # hypercorn - # pep517 -tqdm==4.65.0 - # via - # -r requirements.in - # twine -twine==3.7.1 - # via -r requirements.in -types-awscrt==0.16.23 - # via - # botocore-stubs - # types-s3transfer -types-s3transfer==0.6.1 - # via boto3-stubs -typing-extensions==4.7.1 - # via - # aioredis - # asgiref - # boto3-stubs - # botocore-stubs - # datadog-api-client - # kombu - # mypy-boto3-cloudformation - # mypy-boto3-dynamodb - # mypy-boto3-ec2 - # mypy-boto3-lambda - # mypy-boto3-rds - # mypy-boto3-s3 - # mypy-boto3-sqs - # pydantic - # rich - # sqlalchemy - # starlette - # typing-inspect -typing-inspect==0.9.0 - # via dataclasses-json -tzdata==2023.3 - # via - # backports-zoneinfo - # celery -urllib3==1.26.16 - # via - # botocore - # celery - # datadog-api-client - # google-auth - # kubernetes - # kubernetes-asyncio - # requests -uvicorn==0.17.6 - # via -r requirements.in -uvloop==0.17.0 - # via -r requirements.in -vine==5.0.0 - # via - # amqp - # celery - # kombu -wcwidth==0.2.6 - # via prompt-toolkit -webencodings==0.5.1 - # via bleach -websocket-client==1.6.1 - # via - # docker - # kubernetes -werkzeug==2.3.6 - # via quart -wsproto==1.2.0 - # via hypercorn -yarl==1.9.2 - # via - # -r requirements.in - # aiohttp -zipp==3.16.0 - # via - # importlib-metadata - # importlib-resources - -# The following packages are considered to be unsafe in a requirements file: -setuptools==68.0.0 - # via - # gunicorn - # kubernetes - # kubernetes-asyncio diff --git a/server/requirements_override.txt b/server/requirements_override.txt deleted file mode 100644 index 0c282e6e..00000000 --- a/server/requirements_override.txt +++ /dev/null @@ -1,4 +0,0 @@ -# Consists of packages that are incompatible with requirements.txt -aioboto3==10.0.0 -aiobotocore[boto3]~=2.3.4 -urllib3==1.26.11 diff --git a/server/service_configs/service_config.yaml b/server/service_configs/service_config.yaml deleted file mode 100644 index 1f9c4ef2..00000000 --- a/server/service_configs/service_config.yaml +++ /dev/null @@ -1,66 +0,0 @@ -# Default Configs - -# Endpoint config -# K8s namespace the endpoints will be created in -endpoint_namespace: llm-engine - -# Asynchronous endpoints -sqs_profile: default -sqs_queue_policy_template: > - { - "Version": "2012-10-17", - "Id": "__default_policy_ID", - "Statement": [ - { - "Sid": "__owner_statement", - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:root" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/default" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/ml_llm_engine" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - } - ] - } - -sqs_queue_tag_template: > - { - "infra.scale.com/product": "MLInfraLLMEngineSQS", - "infra.scale.com/team": "${team}", - "infra.scale.com/contact": "yi.xu@scale.com", - "infra.scale.com/customer": "AllCustomers", - "infra.scale.com/financialOwner": "yi.xu@scale.com", - "Spellbook-Serve-Endpoint-Id": "${endpoint_id}", - "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", - "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" - } - -# resultsS3Bucket (i.e. where HMI will store model inference results) is currently determined on endpoint creation -# via a request - -# modelBundleS3Bucket (i.e. where model bundles are stored) is not determined by any HMI code, but instead -# by some scaleapi routing layer code for scale-hosted HMI, and by request parameters in general. - -# Currently, the celery redis used is defaulted to scale's celery redis, and is hardcoded inside scaleml's celery impl. -# We'll need to bundle this celery implementation along for open-source hosting. - -# There's a separate piece of infra that caches k8s state onto redis, so we need a url to it -cache_redis_url: redis://redis-elasticache-message-broker.ml-internal.scale.com:6379/15 -s3_file_llm_fine_tuning_job_repository: "s3://scale-ml/hosted-model-inference/llm-ft-job-repository/circleci" -datadog_trace_enabled: false diff --git a/server/service_configs/service_config_circleci.yaml b/server/service_configs/service_config_circleci.yaml deleted file mode 100644 index ca755dea..00000000 --- a/server/service_configs/service_config_circleci.yaml +++ /dev/null @@ -1,65 +0,0 @@ -# Config for CircleCI - -# Endpoint config -# K8s namespace the endpoints will be created in -endpoint_namespace: llm-engine - -# Asynchronous endpoints -sqs_profile: nonexistent_sqs_profile -sqs_queue_policy_template: > - { - "Version": "2012-10-17", - "Id": "__default_policy_ID", - "Statement": [ - { - "Sid": "__owner_statement", - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:root" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/default" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - }, - { - "Effect": "Allow", - "Principal": { - "AWS": "arn:aws:iam::000000000000:role/ml_llm_engine" - }, - "Action": "sqs:*", - "Resource": "arn:aws:sqs:us-west-2:000000000000:${queue_name}" - } - ] - } - -sqs_queue_tag_template: > - { - "infra.scale.com/product": "MLInfraLLMEngineSQS", - "infra.scale.com/team": "${team}", - "infra.scale.com/contact": "yi.xu@scale.com", - "infra.scale.com/customer": "AllCustomers", - "infra.scale.com/financialOwner": "yi.xu@scale.com", - "Spellbook-Serve-Endpoint-Id": "${endpoint_id}", - "Spellbook-Serve-Endpoint-Name": "${endpoint_name}", - "Spellbook-Serve-Endpoint-Created-By": "${endpoint_created_by}" - } - -# resultsS3Bucket (i.e. where HMI will store model inference results) is currently determined on endpoint creation -# via a request - -# modelBundleS3Bucket (i.e. where model bundles are stored) is not determined by any HMI code, but instead -# by some scaleapi routing layer code for scale-hosted HMI, and by request parameters in general. - -# Currently, the celery redis used is defaulted to scale's celery redis, and is hardcoded inside scaleml's celery impl. -# We'll need to bundle this celery implementation along for open-source hosting. - -# There's a separate piece of infra that caches k8s state onto redis, so we need a url to it -cache_redis_url: redis://127.0.0.1:6379/15 -s3_file_llm_fine_tuning_job_repository: "s3://scale-ml-circleci/hosted-model-inference/llm-ft-job-repository/circleci" diff --git a/server/setup.cfg b/server/setup.cfg deleted file mode 100644 index 69610053..00000000 --- a/server/setup.cfg +++ /dev/null @@ -1,18 +0,0 @@ -[aliases] -test=pytest - -[coverage:run] -omit = - llm_engine/start_server.py, - llm_engine/start_service_builder.py - -[tool:pytest] -addopts = - --verbose - --durations=0 - --cache-clear - --cov=llm_engine - --cov-report=term-missing - --mypy - --mypy-ini-file=mypy.ini - --ignore=clients diff --git a/server/setup.py b/server/setup.py deleted file mode 100644 index 5377b054..00000000 --- a/server/setup.py +++ /dev/null @@ -1,19 +0,0 @@ -# To get circleci to work -from setuptools import find_packages, setup - -setup( - name="scale-llm-engine-server", - version="1.0.0", - packages=[p for p in find_packages() if "tests" not in p], - install_requires=[], - entry_points={ - "console_scripts": [ - "start-service-builder=llm_engine_server.start_service_builder:entrypoint", - "start-server=llm_engine_server.start_server:entrypoint", - "start-fastapi-server=llm_engine_server.entrypoints.start_fastapi_server:entrypoint", - "start-batch-job-orchestration=llm_engine_server.entrypoints.start_batch_job_orchestration:entrypoint", - "hosted-inference-server=llm_engine_server.entrypoints.hosted_inference_server:entrypoint", - "autogen=llm_engine_server.scripts.autogenerate_client_and_docs:entrypoint", - ], - }, -) diff --git a/server/tests/README.md b/server/tests/README.md deleted file mode 100644 index e519b96b..00000000 --- a/server/tests/README.md +++ /dev/null @@ -1,7 +0,0 @@ -## To Run Tests: - -```shell -pushd ../ -PYTHONPATH=llm_engine WORKSPACE=. python3 -m pytest llm_engine/tests --cov=llm_engine -popd -``` diff --git a/server/tests/unit/api/test_llms.py b/server/tests/unit/api/test_llms.py deleted file mode 100644 index 9b4065f6..00000000 --- a/server/tests/unit/api/test_llms.py +++ /dev/null @@ -1,171 +0,0 @@ -import json -from typing import Any, Dict, Tuple - -import pytest -from llm_engine_server.common.dtos.llms import GetLLMModelEndpointV1Response -from llm_engine_server.domain.entities import ModelEndpoint - - -def test_create_llm_model_endpoint_success( - create_llm_model_endpoint_request_sync: Dict[str, Any], - test_api_key: str, - get_test_client_wrapper, -): - client = get_test_client_wrapper( - fake_docker_repository_image_always_exists=True, - fake_model_bundle_repository_contents={}, - fake_model_endpoint_record_repository_contents={}, - fake_model_endpoint_infra_gateway_contents={}, - fake_batch_job_record_repository_contents={}, - fake_batch_job_progress_gateway_contents={}, - fake_docker_image_batch_job_bundle_repository_contents={}, - ) - response_1 = client.post( - "/v1/llm/model-endpoints", - auth=(test_api_key, ""), - json=create_llm_model_endpoint_request_sync, - ) - assert response_1.status_code == 200 - - -def test_list_model_endpoints_success( - llm_model_endpoint_async: Tuple[ModelEndpoint, Any], - model_endpoint_2: Tuple[ModelEndpoint, Any], - get_test_client_wrapper, -): - client = get_test_client_wrapper( - fake_model_endpoint_record_repository_contents={ - llm_model_endpoint_async[0].record.id: llm_model_endpoint_async[0].record, - }, - fake_model_endpoint_infra_gateway_contents={ - llm_model_endpoint_async[0] - .infra_state.deployment_name: llm_model_endpoint_async[0] - .infra_state, - model_endpoint_2[0].infra_state.deployment_name: model_endpoint_2[0].infra_state, - }, - ) - response_1 = client.get( - "/v1/llm/model-endpoints?order_by=newest", - auth=("no_user", ""), - ) - expected_model_endpoint_1 = json.loads( - GetLLMModelEndpointV1Response.parse_obj(llm_model_endpoint_async[1]).json() - ) - assert response_1.status_code == 200 - assert response_1.json() == {"model_endpoints": [expected_model_endpoint_1]} - - -def test_get_llm_model_endpoint_success( - llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], - model_endpoint_2: Tuple[ModelEndpoint, Any], - get_test_client_wrapper, -): - client = get_test_client_wrapper( - fake_model_endpoint_record_repository_contents={ - llm_model_endpoint_sync[0].record.id: llm_model_endpoint_sync[0].record, - }, - fake_model_endpoint_infra_gateway_contents={ - llm_model_endpoint_sync[0] - .infra_state.deployment_name: llm_model_endpoint_sync[0] - .infra_state, - model_endpoint_2[0].infra_state.deployment_name: model_endpoint_2[0].infra_state, - }, - ) - response_1 = client.get( - f"/v1/llm/model-endpoints/{llm_model_endpoint_sync[0].record.name}", - auth=("no_user", ""), - ) - expected_model_endpoint_1 = json.loads( - GetLLMModelEndpointV1Response.parse_obj(llm_model_endpoint_sync[1]).json() - ) - assert response_1.status_code == 200 - assert response_1.json() == expected_model_endpoint_1 - - -def test_completion_sync_success( - llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], - completion_sync_request: Dict[str, Any], - get_test_client_wrapper, -): - client = get_test_client_wrapper( - fake_docker_repository_image_always_exists=True, - fake_model_bundle_repository_contents={}, - fake_model_endpoint_record_repository_contents={ - llm_model_endpoint_sync[0].record.id: llm_model_endpoint_sync[0].record, - }, - fake_model_endpoint_infra_gateway_contents={ - llm_model_endpoint_sync[0] - .infra_state.deployment_name: llm_model_endpoint_sync[0] - .infra_state, - }, - fake_batch_job_record_repository_contents={}, - fake_batch_job_progress_gateway_contents={}, - fake_docker_image_batch_job_bundle_repository_contents={}, - ) - response_1 = client.post( - f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}", - auth=("no_user", ""), - json=completion_sync_request, - ) - assert response_1.status_code == 200 - assert response_1.json() == {"outputs": [], "status": "SUCCESS", "traceback": None} - - -def test_completion_sync_raises_temperature_zero( - llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], - completion_sync_request_temperature_zero: Dict[str, Any], - get_test_client_wrapper, -): - client = get_test_client_wrapper( - fake_docker_repository_image_always_exists=True, - fake_model_bundle_repository_contents={}, - fake_model_endpoint_record_repository_contents={ - llm_model_endpoint_sync[0].record.id: llm_model_endpoint_sync[0].record, - }, - fake_model_endpoint_infra_gateway_contents={ - llm_model_endpoint_sync[0] - .infra_state.deployment_name: llm_model_endpoint_sync[0] - .infra_state, - }, - fake_batch_job_record_repository_contents={}, - fake_batch_job_progress_gateway_contents={}, - fake_docker_image_batch_job_bundle_repository_contents={}, - ) - response_1 = client.post( - f"/v1/llm/completions-sync?model_endpoint_name={llm_model_endpoint_sync[0].record.name}", - auth=("no_user", ""), - json=completion_sync_request_temperature_zero, - ) - assert response_1.status_code == 422 - - -@pytest.mark.skip(reason="Need to figure out FastAPI test client asyncio funkiness") -def test_completion_stream_success( - llm_model_endpoint_streaming: ModelEndpoint, - completion_stream_request: Dict[str, Any], - get_test_client_wrapper, -): - client = get_test_client_wrapper( - fake_docker_repository_image_always_exists=True, - fake_model_bundle_repository_contents={}, - fake_model_endpoint_record_repository_contents={ - llm_model_endpoint_streaming.record.id: llm_model_endpoint_streaming.record, - }, - fake_model_endpoint_infra_gateway_contents={ - llm_model_endpoint_streaming.infra_state.deployment_name: llm_model_endpoint_streaming.infra_state, - }, - fake_batch_job_record_repository_contents={}, - fake_batch_job_progress_gateway_contents={}, - fake_docker_image_batch_job_bundle_repository_contents={}, - ) - response_1 = client.post( - f"/v1/llm/completions-stream?model_endpoint_name={llm_model_endpoint_streaming.record.name}", - auth=("no_user", ""), - json=completion_stream_request, - ) - assert response_1.status_code == 200 - count = 0 - for message in response_1: - assert message == b'data: {"status": "SUCCESS", "output": null, "traceback": null}\r\n\r\n' - count += 1 - assert count == 1 diff --git a/server/tests/unit/core/test_env.py b/server/tests/unit/core/test_env.py deleted file mode 100644 index 8765d4ed..00000000 --- a/server/tests/unit/core/test_env.py +++ /dev/null @@ -1,105 +0,0 @@ -import os -from typing import Any, Callable, Dict, Optional, Sequence -from uuid import uuid4 - -import pytest -from llm_engine_server.core.utils.env import environment - -# DO NOT EXPORT ANYTHING -__all__: Sequence[str] = () - - -def expect_not_present(e: str) -> None: - assert ( - e not in os.environ - ), f"Not expecting env var {e} to be present, instead found {os.environ[e]}" - - -def expect_present(e: str, value: Any) -> None: - assert e in os.environ, f"Expecting env var {e} to be present with {value}" - assert ( - os.environ[e] == value - ), f"Expected env var {e} to have value {value} but instead found {os.environ[e]}" - - -def prepare(e: str, existing: Optional[str]) -> Callable[[], None]: - if existing is not None: - os.environ[e] = existing - return lambda: expect_present(e, existing) - else: - return lambda: expect_not_present(e) - - -def test_environment_kwarg(): - e = "ENV_VAR_TEST" - expect_not_present(e) - # NOTE: This is to test keyword argument use. - # Make sure this **literal value** is the same as `e`'s contents. - with environment(ENV_VAR_TEST="x"): - expect_present(e, "x") - expect_not_present(e) - - -@pytest.mark.parametrize("existing", ["env var has prior value", None]) -def test_environment_normal_cases(existing): - e = f"___{uuid4()}-test_env_var" - check = prepare(e, existing) - - check() - new = f"{uuid4()}--hello_world" - with environment(**{e: new}): - expect_present(e, new) - check() - - -@pytest.mark.parametrize("existing", ["env var has prior value", None]) -def test_environment_with_exception(existing): - e = f"___{uuid4()}-test_env_var" - check = prepare(e, existing) - - check() - new = f"{uuid4()}--hello_world" - with pytest.raises(ValueError): - with environment(**{e: new}): - expect_present(e, new) - raise ValueError("Uh oh! Something went wrong in our context!") - check() - - -def test_environment_multi(): - env_vars: Dict[str, str] = {f"___{uuid4()}-test_env_var--{i}": f"value_{i}" for i in range(25)} - - def ok(): - for e in env_vars.keys(): - expect_not_present(e) - - ok() - with environment(**env_vars): - for e, v in env_vars.items(): - expect_present(e, v) - ok() - - -def test_environment_invalid_states(): - with pytest.raises(ValueError): - environment(**{"": "2"}) - - -def test_environment_unset(): - k = f"___{uuid4()}___--test_unset_env_var--" - v = "hello world! :)" - # when there is a previous value - try: - os.environ[k] = v - with environment(**{k: None}): - assert k not in os.environ - assert k in os.environ - assert os.environ[k] == v - finally: - del os.environ[k] - - # when there is not a previous value - assert k not in os.environ - with environment(**{k: None}): - assert k not in os.environ - assert k not in os.environ diff --git a/server/tests/unit/db/common/test_query.py b/server/tests/unit/db/common/test_query.py deleted file mode 100644 index 0d9b173a..00000000 --- a/server/tests/unit/db/common/test_query.py +++ /dev/null @@ -1,18 +0,0 @@ -from dataclasses import dataclass - -from llm_engine_server.db.models.common.query import Query - - -@dataclass -class ExampleQuery(Query): - """ - Example query - """ - - id: str - name: str - - -def test_query(): - query = ExampleQuery(id="123", name="test") - assert query.to_sqlalchemy_query() == {"id": "123", "name": "test"} diff --git a/server/tests/unit/db/common/test_repository.py b/server/tests/unit/db/common/test_repository.py deleted file mode 100644 index c3dfacf5..00000000 --- a/server/tests/unit/db/common/test_repository.py +++ /dev/null @@ -1,69 +0,0 @@ -from unittest.mock import MagicMock, patch - -import pytest -from llm_engine_server.db.models.common.record import Record - - -@pytest.fixture -def mock_session(): - return MagicMock() - - -@pytest.fixture -def mock_query(): - return MagicMock() - - -class TestRecord: - """ - Test the Record class. - """ - - def test_create(self, mock_session): - item = MagicMock() - Record.create(session=mock_session, record=item) - mock_session.add.assert_called_once_with(item) - mock_session.commit.assert_called_once_with() - - @patch("llm_engine_server.db.models.common.record.select") - def test_select_all(self, mock_select, mock_session, mock_query): - mock_query.to_sqlalchemy_query.return_value = {"id": "123", "name": "test"} - mock_select_obj = MagicMock() - mock_select.return_value = mock_select_obj - Record.select_all(session=mock_session, query=mock_query) - mock_select.assert_called_once_with(Record) - mock_select_obj.filter_by.assert_called_once_with(id="123", name="test") - mock_session.execute.assert_called_once_with(mock_select_obj.filter_by.return_value) - mock_session.execute.return_value.scalars.assert_called_once_with() - mock_session.execute.return_value.scalars.return_value.all.assert_called_once_with() - - @patch("llm_engine_server.db.models.common.record.select") - def test_select_by_id(self, mock_select, mock_session): - mock_select_obj = MagicMock() - mock_select.return_value = mock_select_obj - Record.select_by_id(session=mock_session, record_id="123") - mock_select.assert_called_once_with(Record) - mock_select_obj.filter_by.assert_called_once_with(id="123") - mock_session.execute.assert_called_once_with(mock_select_obj.filter_by.return_value) - mock_session.execute.return_value.scalar_one_or_none.assert_called_once_with() - - @patch("llm_engine_server.db.models.common.record.select") - def test_update(self, mock_select, mock_session, mock_query): - mock_select_obj = MagicMock() - mock_select.return_value = mock_select_obj - mock_query.to_sqlalchemy_query.return_value = {"name": "test"} - item = MagicMock() - mock_session.execute.return_value.scalar_one_or_none.return_value = item - Record.update(session=mock_session, record_id="123", query=mock_query) - mock_select.assert_called_once_with(Record) - mock_select_obj.filter_by.assert_called_once_with(id="123") - mock_session.execute.assert_called_once_with(mock_select_obj.filter_by.return_value) - mock_session.execute.return_value.scalar_one_or_none.assert_called_once_with() - item.name = "test" - mock_session.commit.assert_called_once_with() - - def test_delete(self, mock_session): - item = MagicMock() - Record.delete(session=mock_session, record=item) - mock_session.delete.assert_called_once_with(item) - mock_session.commit.assert_called_once_with() diff --git a/server/tests/unit/db/conftest.py b/server/tests/unit/db/conftest.py deleted file mode 100644 index 1a01c532..00000000 --- a/server/tests/unit/db/conftest.py +++ /dev/null @@ -1,467 +0,0 @@ -import datetime -import os -from typing import List - -import psycopg2 -import pytest -import pytest_asyncio -import testing.postgresql -from llm_engine_server.db.base import Session, SessionAsync -from llm_engine_server.db.local_setup import init_database, init_database_and_engine -from llm_engine_server.db.models import ( - BatchJob, - Bundle, - DockerImageBatchJobBundle, - Endpoint, - Model, - ModelArtifact, - ModelVersion, -) -from sqlalchemy import create_engine -from sqlalchemy.engine import Engine -from sqlalchemy.ext.asyncio import create_async_engine - - -def init_testing_postgresql(postgresql: testing.postgresql.Postgresql) -> None: - """Initializes local postgresql server.""" - conn = psycopg2.connect(**postgresql.dsn()) - init_database(postgresql.url(), conn) # type: ignore - - -@pytest.fixture(scope="session") -def engine() -> Engine: - if os.getenv("ML_INFRA_DATABASE_URL"): - url = os.getenv("ML_INFRA_DATABASE_URL") - db_engine = init_database_and_engine(url) - yield db_engine - else: - Postgresql = testing.postgresql.PostgresqlFactory( - cache_initialized_db=True, - on_initialized=init_testing_postgresql, - ) - postgresql = Postgresql().__enter__() - yield create_engine(postgresql.url(), echo=False, future=True) - - -@pytest.fixture(scope="function") -def dbsession(engine: Engine) -> Session: - """Returns a sqlalchemy session, and after the test tears down everything properly.""" - connection = engine.connect() - transaction = connection.begin() - session = Session(bind=connection) - - yield session - - session.close() - transaction.rollback() - connection.close() - - -@pytest_asyncio.fixture(scope="function") -async def dbsession_async(engine: Engine) -> SessionAsync: - """Returns a sqlalchemy session, and after the test tears down everything properly.""" - url = str(engine.url).replace("postgresql://", "postgresql+asyncpg://") - engine = create_async_engine(url) - async with engine.connect() as connection: - async with connection.begin() as transaction: - session = SessionAsync(bind=connection) - yield session - await session.close() - await transaction.rollback() - await connection.close() - - -@pytest_asyncio.fixture(scope="function") -async def bundles(dbsession_async: SessionAsync) -> List[Bundle]: - bundle1 = Bundle( - name="test_bundle_1", - created_by="test_user_1", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="cloudpickle_artifact", - # Artifact fields - artifact_requirements=["test_requirement_1"], - artifact_location="test_location_1", - artifact_app_config=None, - artifact_framework_type="pytorch", - artifact_pytorch_image_tag="test_tag_1", - # Cloudpickle artifact fields - cloudpickle_artifact_load_predict_fn="test_load_predict_fn", - cloudpickle_artifact_load_model_fn="test_load_model_fn", - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundle2 = Bundle( - name="test_bundle_2", - created_by="test_user_1", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="zip_artifact", - # Artifact fields - artifact_requirements=["test_requirement_1"], - artifact_location="test_location_2", - artifact_app_config={"test_key": "test_value"}, - artifact_framework_type="custom_base_image", - artifact_image_repository="test_repo_1", - artifact_image_tag="test_tag_1", - # Zip artifact fields - zip_artifact_load_predict_fn_module_path="test_path_1", - zip_artifact_load_model_fn_module_path="test_path_2", - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundle3 = Bundle( - name="test_bundle_3", - created_by="test_user_2", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="runnable_image", - # Runnable Image fields - runnable_image_repository="test_repository_1", - runnable_image_tag="test_tag_1", - runnable_image_command=["test_command_1"], - runnable_image_predict_route="/test_predict_route", - runnable_image_healthcheck_route="/test_healthcheck_route", - runnable_image_env={"test_key": "test_value"}, - runnable_image_protocol="http", - runnable_image_readiness_initial_delay_seconds=300, - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundle4 = Bundle( - name="test_bundle_4", - created_by="test_user_2", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="triton_enhanced_runnable_image", - # Runnable Image fields - runnable_image_repository="test_repository_1", - runnable_image_tag="test_tag_1", - runnable_image_command=["test_command_1"], - runnable_image_predict_route="/test_predict_route", - runnable_image_healthcheck_route="/test_healthcheck_route", - runnable_image_env={"test_key": "test_value"}, - runnable_image_protocol="http", - runnable_image_readiness_initial_delay_seconds=300, - # Triton enhanced runnable image fields - triton_enhanced_runnable_image_model_repository="test_model_repository_1", - triton_enhanced_runnable_image_model_replicas={"test_model_1": "test_val"}, - triton_enhanced_runnable_image_num_cpu=3.5, - triton_enhanced_runnable_image_commit_tag="test_commit_tag_1", - triton_enhanced_runnable_image_storage="test_storage_1", - triton_enhanced_runnable_image_readiness_initial_delay_seconds=350, - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundle5 = Bundle( - name="test_bundle_5", - created_by="test_user_2", - model_artifact_ids=None, - schema_location=None, - owner="test_user_1", - flavor="streaming_enhanced_runnable_image", - # Runnable Image fields - runnable_image_repository="test_repository_1", - runnable_image_tag="test_tag_1", - runnable_image_command=["test_command_1"], - runnable_image_predict_route="/test_predict_route", - runnable_image_healthcheck_route="/test_healthcheck_route", - runnable_image_env={"test_key": "test_value"}, - runnable_image_protocol="http", - runnable_image_readiness_initial_delay_seconds=300, - # Streaming Enhanced Runnable Image fields - streaming_enhanced_runnable_image_streaming_command=["test_streaming_command_1"], - streaming_enhanced_runnable_image_streaming_predict_route="/test_streaming_predict_route", - # Legacy fields - location="test_location_1", - version="v0", - registered_model_name="registered_model_name_1", - bundle_metadata=None, - env_params=None, - packaging_type=None, - app_config=None, - ) - bundles = [bundle1, bundle2, bundle3, bundle4, bundle5] - for bundle in bundles: - await Bundle.create(dbsession_async, bundle) - return bundles - - -@pytest_asyncio.fixture(scope="function") -async def endpoints(dbsession_async: SessionAsync, bundles: List[Bundle]) -> List[Endpoint]: - endpoint1 = Endpoint( - name="test_endpoint_1", - created_by="test_user_1", - current_bundle_id=bundles[0].id, - endpoint_metadata=None, - creation_task_id="test_creation_task_id_1", - endpoint_type="async", - destination="test_destination_1", - endpoint_status="READY", - owner="test_user_1", - ) - endpoint2 = Endpoint( - name="test_endpoint_2", - created_by="test_user_1", - current_bundle_id=bundles[0].id, - endpoint_metadata=None, - creation_task_id="test_creation_task_id_1", - endpoint_type="async", - destination="test_destination_1", - endpoint_status="READY", - owner="test_user_1", - ) - endpoint3 = Endpoint( - name="test_endpoint_3", - created_by="test_user_1", - current_bundle_id=bundles[1].id, - endpoint_metadata=None, - creation_task_id="test_creation_task_id_1", - endpoint_type="async", - destination="test_destination_1", - endpoint_status="READY", - owner="test_user_1", - ) - endpoints = [endpoint1, endpoint2, endpoint3] - for endpoint in endpoints: - await Endpoint.create(dbsession_async, endpoint) - return endpoints - - -@pytest_asyncio.fixture(scope="function") -async def batch_jobs( - dbsession_async: SessionAsync, bundles: List[Bundle], endpoints: List[Endpoint] -) -> List[BatchJob]: - batch_job1 = BatchJob( - batch_job_status="READY", - created_by="test_user_1", - owner="test_user_1", - model_bundle_id=bundles[0].id, - model_endpoint_id=endpoints[0].id, - task_ids_location=None, - ) - batch_job2 = BatchJob( - batch_job_status="READY", - created_by="test_user_1", - owner="test_user_1", - model_bundle_id=bundles[0].id, - model_endpoint_id=endpoints[0].id, - task_ids_location=None, - ) - batch_job3 = BatchJob( - batch_job_status="READY", - created_by="test_user_2", - owner="test_user_2", - model_bundle_id=bundles[1].id, - model_endpoint_id=endpoints[2].id, - task_ids_location=None, - ) - jobs = [batch_job1, batch_job2, batch_job3] - for batch_job in jobs: - await BatchJob.create(dbsession_async, batch_job) - return jobs - - -@pytest_asyncio.fixture(scope="function") -async def docker_image_batch_job_bundles( - dbsession_async: SessionAsync, -) -> List[DockerImageBatchJobBundle]: - batch_bundle_1 = DockerImageBatchJobBundle( - name="test_docker_image_batch_job_bundle_1", - created_by="test_user_1", - owner="test_user_1", - image_repository="image_repository", - image_tag="image_tag_git_sha", - command=["python", "script.py", "--arg1"], - env=dict(ENV1="VAL1", ENV2="VAL2"), - mount_location="/mount/location/to/config", - cpus="1", - memory=None, - storage=None, - gpus=None, - gpu_type=None, - public=None, - ) - batch_bundle_2 = DockerImageBatchJobBundle( - name="test_docker_image_batch_job_bundle_1", - created_by="test_user_1", - owner="test_user_1", - image_repository="image_repository", - image_tag="image_tag_git_sha", - command=["python", "script.py", "--arg2"], - env=dict(ENV1="VAL3", ENV2="VAL4"), - mount_location="/mount/location/to/config2", - cpus="2", - memory=None, - storage=None, - gpus=None, - gpu_type=None, - public=None, - ) - batch_bundle_3 = DockerImageBatchJobBundle( - name="test_docker_image_batch_job_bundle_2", - created_by="test_user_2", - owner="test_user_2", - image_repository="image_repository", - image_tag="image_tag_git_sha", - command=["python", "script2.py", "--arg1"], - env=dict(ENV1="VAL1", ENV2="VAL2"), - mount_location="/mount2/location/to/config", - cpus="3", - memory=None, - storage=None, - gpus=None, - gpu_type=None, - public=None, - ) - batch_bundle_1.created_at = datetime.datetime(2022, 1, 1) - batch_bundle_2.created_at = datetime.datetime(2022, 1, 3) - batch_bundle_3.created_at = datetime.datetime(2022, 1, 2) - batch_bundles = [batch_bundle_1, batch_bundle_2, batch_bundle_3] - for batch_bundle in batch_bundles: - await DockerImageBatchJobBundle.create(dbsession_async, batch_bundle) - return batch_bundles - - -@pytest.fixture(scope="function") -def models(dbsession: Session) -> List[Model]: - model1 = Model( - name="test_model_1", - description="test_description_1", - task_types=["test_task_type_1", "test_task_type_2"], - created_by="test_user_id_1", - owner="test_user_id_1", - ) - model2 = Model( - name="test_model_2", - description="test_description_2", - task_types=["test_task_type_1", "test_task_type_3"], - created_by="test_user_id_1", - owner="test_user_id_1", - ) - model3 = Model( - name="test_model_1", - description="test_description_1", - task_types=["test_task_type_2", "test_task_type_3"], - created_by="test_user_id_2", - owner="test_user_id_2", - ) - models = [model1, model2, model3] - for model in models: - Model.create(dbsession, model) - return models - - -@pytest_asyncio.fixture(scope="function") -async def model_versions( - dbsession: Session, models: List[Model], bundles: List[Bundle] -) -> List[ModelVersion]: - model_version1 = ModelVersion( - model_id=models[0].id, - version_number=0, - tags=["test_tag_1", "test_tag_2"], - metadata={"key1": "value1"}, - created_by="test_user_id_1", - ) - model_version2 = ModelVersion( - model_id=models[0].id, - version_number=1, - llm_engine_model_bundle_id=bundles[0].id, - tags=["test_tag_1", "test_tag_3"], - metadata={"key1": "value2"}, - created_by="test_user_id_1", - ) - model_version3 = ModelVersion( - model_id=models[2].id, - version_number=0, - llm_engine_model_bundle_id=bundles[1].id, - nucleus_model_id="test_nucleus_model_id_1", - tags=["test_tag_1", "test_tag_2"], - metadata={"key2": "value3"}, - created_by="test_user_id_1", - ) - model_versions = [model_version1, model_version2, model_version3] - for model_version in model_versions: - ModelVersion.create(dbsession, model_version) - return model_versions - - -@pytest.fixture(scope="function") -def model_artifacts(dbsession: Session) -> List[ModelArtifact]: - model_artifact1 = ModelArtifact( - name="test_model_artifact_1", - description="test_description_1", - is_public=True, - created_by="test_user_id_1", - owner="test_user_id_1", - input_schema={"test_schema_key": "test_schema_value"}, - output_schema={"test_schema_key": "test_schema_value"}, - config={"test_config_key": "test_config_value"}, - location="test_location", - format="pytorch", - format_metadata={"test_format_key": "test_format_value"}, - source="huggingface", - source_metadata={"test_source_key": "test_source_value"}, - ) - model_artifact2 = ModelArtifact( - name="test_model_artifact_2", - description="test_description_2", - is_public=False, - created_by="test_user_id_1", - owner="test_user_id_1", - input_schema={"test_schema_key": "test_schema_value"}, - output_schema={"test_schema_key": "test_schema_value"}, - config={"test_config_key": "test_config_value"}, - location="test_location", - format="pytorch", - format_metadata={"test_format_key": "test_format_value"}, - source="huggingface", - source_metadata={"test_source_key": "test_source_value"}, - ) - model_artifact3 = ModelArtifact( - name="test_model_artifact_3", - description="test_description_3", - is_public=True, - created_by="test_user_id_2", - owner="test_user_id_2", - input_schema={"test_schema_key": "test_schema_value"}, - output_schema={"test_schema_key": "test_schema_value"}, - config={"test_config_key": "test_config_value"}, - location="test_location", - format="tensorflow", - format_metadata={"test_format_key": "test_format_value"}, - source="mlflow", - source_metadata={"test_source_key": "test_source_value"}, - ) - model_artifacts = [model_artifact1, model_artifact2, model_artifact3] - for model_artifact in model_artifacts: - ModelArtifact.create(dbsession, model_artifact) - return model_artifacts diff --git a/server/tests/unit/db/test_endpoint_row_lock.py b/server/tests/unit/db/test_endpoint_row_lock.py deleted file mode 100644 index ed7879e0..00000000 --- a/server/tests/unit/db/test_endpoint_row_lock.py +++ /dev/null @@ -1,22 +0,0 @@ -# Since the bulk of the file involves actually connecting to postgres, we're only gonna test that the -# `get_lock_key` function doesn't error and returns nonnegative ints from 0 to 2**64-1 - -from llm_engine_server.db.base import Session -from llm_engine_server.db.endpoint_row_lock import AdvisoryLockContextManager, get_lock_key - - -def test_get_lock_key(): - pairs = [ - ("userid1", "endpointname1"), - ("userid2", "endpointname2"), - ("userid", "1endpointname1"), - ("endpointname1", "userid1"), - ] + [(str(i), str(i)) for i in range(10000)] - keys = [get_lock_key(uid, name) for uid, name in pairs] - assert len(keys) == len(set(keys)), "Key collision found" - assert all([-(2**63) <= key < 2**63 for key in keys]), "Key falls outside of range" - - -def test_lock_context_manager(dbsession: Session): - with AdvisoryLockContextManager(session=dbsession, lock_id=10) as lock: - assert lock.lock_acquired() diff --git a/server/tests/unit/db/test_llm_engine.py b/server/tests/unit/db/test_llm_engine.py deleted file mode 100644 index 51811464..00000000 --- a/server/tests/unit/db/test_llm_engine.py +++ /dev/null @@ -1,160 +0,0 @@ -from datetime import datetime -from typing import List - -import pytest -from llm_engine_server.db.base import SessionAsync -from llm_engine_server.db.models import BatchJob, Bundle, DockerImageBatchJobBundle, Endpoint - - -@pytest.mark.asyncio -async def test_bundle_select(dbsession_async: SessionAsync, bundles: List[Bundle]): - bundle_by_name_created_by = await Bundle.select_by_name_created_by( - dbsession_async, name="test_bundle_1", created_by="test_user_1" - ) - assert bundle_by_name_created_by is not None - - bundle_by_name_owner = await Bundle.select_by_name_owner( - dbsession_async, name="test_bundle_1", owner="test_user_1" - ) - assert bundle_by_name_owner is not None - - bundles_by_name_created_by = await Bundle.select_all_by_name_created_by( - dbsession_async, name="test_bundle_1", created_by="test_user_1" - ) - assert len(bundles_by_name_created_by) == 1 - - bundles_by_name_owner = await Bundle.select_all_by_name_owner( - dbsession_async, name="test_bundle_1", owner="test_user_1" - ) - assert len(bundles_by_name_owner) == 1 - - bundle_by_id = await Bundle.select_by_id(dbsession_async, bundle_id=bundles[0].id) - assert bundle_by_id is not None - - bundles_by_owner = await Bundle.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - assert len(bundles_by_owner) == 2 - - -@pytest.mark.asyncio -async def test_bundle_select_delete(dbsession_async: SessionAsync, bundles: List[Bundle]): - bundles_by_owner = await Bundle.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - prev_num_bundles = len(bundles_by_owner) - - await Bundle.delete(dbsession_async, bundles_by_owner[0]) - - # After deletion, there should now be 1 fewer bundles for this user. - bundles_by_owner = await Bundle.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - assert len(bundles_by_owner) == prev_num_bundles - 1 - - -@pytest.mark.asyncio -async def test_endpoint_select( - dbsession_async: SessionAsync, bundles: List[Bundle], endpoints: List[Endpoint] -): - endpoint_by_name_created_by = await Endpoint.select_by_name_created_by( - dbsession_async, name="test_endpoint_1", created_by="test_user_1" - ) - assert endpoint_by_name_created_by is not None - - endpoints_by_created_by = await Endpoint.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - assert len(endpoints_by_created_by) == 3 - - endpoints_by_owner = await Endpoint.select_all_by_owner(dbsession_async, owner="test_user_1") - assert len(endpoints_by_owner) == 3 - - endpoints_by_bundle_owner = await Endpoint.select_all_by_bundle_created_by( - dbsession_async, current_bundle_id=bundles[0].id, created_by="test_user_1" - ) - assert len(endpoints_by_bundle_owner) == 2 - - -@pytest.mark.asyncio -async def test_endpoint_select_delete( - dbsession_async: SessionAsync, bundles: List[Bundle], endpoints: List[Endpoint] -): - endpoints_by_user_id = await Endpoint.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - prev_num_endpoints = len(endpoints_by_user_id) - - await Endpoint.delete(dbsession_async, endpoints_by_user_id[0]) - - # After deletion, there should now be 1 fewer endpoints for this user. - endpoints_by_user_id = await Endpoint.select_all_by_created_by( - dbsession_async, created_by="test_user_1" - ) - assert len(endpoints_by_user_id) == prev_num_endpoints - 1 - - -@pytest.mark.asyncio -async def test_batch_job_select(dbsession_async: SessionAsync, batch_jobs: List[BatchJob]): - batch_job_by_id = await BatchJob.select_by_id(dbsession_async, batch_job_id=batch_jobs[0].id) - assert batch_job_by_id is not None - - batch_jobs_by_owner = await BatchJob.select_all_by_owner(dbsession_async, owner="test_user_1") - assert len(batch_jobs_by_owner) == 2 - - batch_jobs_by_owner = await BatchJob.select_all_by_bundle_owner( - dbsession_async, - model_bundle_id=batch_jobs[0].model_bundle_id, - owner="test_user_1", - ) - assert len(batch_jobs_by_owner) == 2 - - -@pytest.mark.asyncio -async def test_batch_job_update(dbsession_async: SessionAsync, batch_jobs: List[BatchJob]): - update_kwargs = {"status": "FAILED", "completed_at": datetime.now()} - await BatchJob.update_by_id( - session=dbsession_async, batch_job_id=batch_jobs[0].id, kwargs=update_kwargs - ) - batch_job = await BatchJob.select_by_id(dbsession_async, batch_job_id=batch_jobs[0].id) - assert batch_job is not None - assert batch_job.batch_job_status == update_kwargs["status"] - assert batch_job.completed_at.second == update_kwargs["completed_at"].second # type: ignore - - -@pytest.mark.asyncio -async def test_docker_image_batch_job_bundle_select( - dbsession_async: SessionAsync, - docker_image_batch_job_bundles: List[DockerImageBatchJobBundle], -): - batch_job_by_id = await DockerImageBatchJobBundle.select_by_id( - dbsession_async, batch_bundle_id=docker_image_batch_job_bundles[0].id - ) - assert batch_job_by_id is not None - - batch_jobs_by_owner = await DockerImageBatchJobBundle.select_all_by_owner( - dbsession_async, owner="test_user_1" - ) - assert len(batch_jobs_by_owner) == 2 - - batch_jobs_by_owner = await DockerImageBatchJobBundle.select_all_by_name_owner( - dbsession_async, - name=docker_image_batch_job_bundles[0].name, - owner="test_user_1", - ) - assert len(batch_jobs_by_owner) == 2 - - batch_jobs_by_owner = await DockerImageBatchJobBundle.select_all_by_name_owner( - dbsession_async, - name=docker_image_batch_job_bundles[2].name, - owner="test_user_2", - ) - assert len(batch_jobs_by_owner) == 1 - - batch_job_latest_by_name_owner = await DockerImageBatchJobBundle.select_latest_by_name_owner( - dbsession_async, - name=docker_image_batch_job_bundles[0].name, - owner="test_user_1", - ) - assert batch_job_latest_by_name_owner is not None - assert batch_job_latest_by_name_owner.id == docker_image_batch_job_bundles[1].id diff --git a/server/tests/unit/db/test_model.py b/server/tests/unit/db/test_model.py deleted file mode 100644 index 1bd6fab3..00000000 --- a/server/tests/unit/db/test_model.py +++ /dev/null @@ -1,140 +0,0 @@ -from typing import List - -from llm_engine_server.db.base import Session -from llm_engine_server.db.models import Bundle, Model, ModelArtifact, ModelVersion - - -def test_model_select(dbsession: Session, models: List[Model]): - models_by_owner = Model.select(dbsession, owner="test_user_id_1") - assert len(models_by_owner) == 2 - - models_by_name = Model.select(dbsession, owner="test_user_id_1", name="test_model_1") - assert len(models_by_name) == 1 - - models_by_created_by = Model.select( - dbsession, owner="test_user_id_1", created_by="test_user_id_1" - ) - assert len(models_by_created_by) == 2 - - models_by_task_types = Model.select( - dbsession, owner="test_user_id_1", task_types=["test_task_type_1"] - ) - assert len(models_by_task_types) == 2 - - model_by_id = Model.select_by_id(dbsession, model_id=models[0].id) - assert model_by_id is not None - - -def test_model_update(dbsession: Session, models: List[Model]): - Model.update_by_id(dbsession, models[0].id, description="new description") - model = Model.select_by_id(dbsession, models[0].id) - assert model is not None - assert model.description == "new description" - - -def test_model_version_select( - dbsession: Session, models: List[Model], model_versions: List[ModelVersion] -): - model_versions_by_owner = ModelVersion.select(dbsession, owner="test_user_id_1") - assert len(model_versions_by_owner) == 2 - - model_versions_by_model_id = ModelVersion.select( - dbsession, owner="test_user_id_1", model_id=models[0].id - ) - assert len(model_versions_by_model_id) == 2 - - model_versions_by_model_name = ModelVersion.select( - dbsession, owner="test_user_id_1", model_name="test_model_1" - ) - assert len(model_versions_by_model_name) == 2 - - model_versions_by_tags = ModelVersion.select( - dbsession, owner="test_user_id_1", tags=["test_tag_1"] - ) - assert len(model_versions_by_tags) == 2 - - model_version_by_id = ModelVersion.select_by_id( - dbsession, model_version_id=model_versions[0].id - ) - assert model_version_by_id is not None - - -def test_model_version_select_by_model_id( - dbsession: Session, - bundles: List[Bundle], - models: List[Model], - model_versions: List[ModelVersion], -): - model_version_by_bundle_id = ModelVersion.select_by_llm_engine_model_bundle_id( - dbsession, bundles[0].id - ) - assert model_version_by_bundle_id is not None - assert model_version_by_bundle_id.llm_engine_model_bundle_id == bundles[0].id - - model_version_by_nucleus_model_id = ModelVersion.select_by_nucleus_model_id( - dbsession, model_versions[2].nucleus_model_id # type: ignore - ) - assert model_version_by_nucleus_model_id is not None - - -def test_model_version_get_highest_version_number( - dbsession: Session, models: List[Model], model_versions: List[ModelVersion] -): - version_number = ModelVersion.get_highest_version_number_for_model( - dbsession, - model_id=models[0].id, - ) - assert version_number == 1 - - version_number = ModelVersion.get_highest_version_number_for_model( - dbsession, - model_id=models[1].id, - ) - assert version_number is None - - version_number = ModelVersion.get_highest_version_number_for_model( - dbsession, - model_id="unknown id", - ) - assert version_number is None - - -def test_model_version_update( - dbsession: Session, models: List[Model], model_versions: List[ModelVersion] -): - ModelVersion.update_by_id( - dbsession, model_versions[0].id, nucleus_model_id="test_nucleus_model_id_upd" - ) - model_version = ModelVersion.select_by_id(dbsession, model_versions[0].id) - assert model_version is not None - assert model_version.nucleus_model_id == "test_nucleus_model_id_upd" - - -def test_model_artifact_select(dbsession: Session, model_artifacts: List[ModelArtifact]): - model_artifacts_by_owner = ModelArtifact.select(dbsession, owner="test_user_id_1") - assert len(model_artifacts_by_owner) == 3 - - model_artifacts_by_no_owner = ModelArtifact.select(dbsession) - assert len(model_artifacts_by_no_owner) == 2 - - model_artifacts_by_name = ModelArtifact.select( - dbsession, owner="test_user_id_1", name="test_model_artifact_1" - ) - assert len(model_artifacts_by_name) == 1 - - model_artifacts_by_created_by = ModelArtifact.select( - dbsession, owner="test_user_id_1", created_by="test_user_id_1" - ) - assert len(model_artifacts_by_created_by) == 2 - - model_artifact_by_id = ModelArtifact.select_by_id( - dbsession, model_artifact_id=model_artifacts[0].id - ) - assert model_artifact_by_id is not None - - -def test_model_artifact_update(dbsession: Session, model_artifacts: List[ModelArtifact]): - ModelArtifact.update_by_id(dbsession, model_artifacts[0].id, description="new description") - updated_model_artifact = ModelArtifact.select_by_id(dbsession, model_artifacts[0].id) - assert updated_model_artifact is not None - assert updated_model_artifact.description == "new description" diff --git a/server/tests/unit/domain/conftest.py b/server/tests/unit/domain/conftest.py deleted file mode 100644 index 5c993d0b..00000000 --- a/server/tests/unit/domain/conftest.py +++ /dev/null @@ -1,348 +0,0 @@ -import pytest -from llm_engine_server.common.dtos.batch_jobs import ( - CreateDockerImageBatchJobBundleV1Request, - CreateDockerImageBatchJobResourceRequests, -) -from llm_engine_server.common.dtos.llms import ( - CompletionStreamV1Request, - CompletionSyncV1Request, - CreateLLMModelEndpointV1Request, -) -from llm_engine_server.common.dtos.model_bundles import ( - CreateModelBundleV1Request, - CreateModelBundleV2Request, -) -from llm_engine_server.common.dtos.model_endpoints import ( - CreateModelEndpointV1Request, - UpdateModelEndpointV1Request, -) -from llm_engine_server.domain.entities import ( - GpuType, - ModelBundle, - ModelBundleEnvironmentParams, - ModelBundleFrameworkType, - ModelBundlePackagingType, - ModelEndpointType, - Quantization, - StreamingEnhancedRunnableImageFlavor, -) - - -@pytest.fixture -def create_model_bundle_request() -> CreateModelBundleV1Request: - env_params = ModelBundleEnvironmentParams( - framework_type=ModelBundleFrameworkType.CUSTOM, - ecr_repo="test_repo", - image_tag="test_tag", - ) - return CreateModelBundleV1Request( - name="test_bundle_name", - location="test_location", - requirements=["numpy==0.0.0"], - env_params=env_params, - packaging_type=ModelBundlePackagingType.CLOUDPICKLE, - metadata=None, - app_config=None, - ) - - -@pytest.fixture -def create_model_bundle_v2_request() -> CreateModelBundleV2Request: - return CreateModelBundleV2Request( - name="test_bundle_name", - metadata=None, - schema_location="s3://test-bucket/test-key", - flavor=StreamingEnhancedRunnableImageFlavor( - flavor="streaming_enhanced_runnable_image", - repository="test_repo", - tag="test_tag", - command=["test_command"], - env={"test_key": "test_value"}, - protocol="http", - readiness_initial_delay_seconds=30, - streaming_command=["test_streaming_command"], - streaming_predict_route="/test_streaming_predict_route", - ), - ) - - -@pytest.fixture -def create_model_endpoint_request_sync( - model_bundle_1: ModelBundle, -) -> CreateModelEndpointV1Request: - return CreateModelEndpointV1Request( - name="test_endpoint_name_1", - model_bundle_id=model_bundle_1.id, - endpoint_type=ModelEndpointType.SYNC, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=1, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, - min_workers=1, - max_workers=3, - per_worker=2, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def create_model_endpoint_request_streaming( - model_bundle_5: ModelBundle, -) -> CreateModelEndpointV1Request: - return CreateModelEndpointV1Request( - name="test_endpoint_name_2", - model_bundle_id=model_bundle_5.id, - endpoint_type=ModelEndpointType.STREAMING, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=1, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage="10G", - min_workers=1, - max_workers=3, - per_worker=1, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def create_model_endpoint_request_async( - model_bundle_1: ModelBundle, -) -> CreateModelEndpointV1Request: - return CreateModelEndpointV1Request( - name="test_endpoint_name_2", - model_bundle_id=model_bundle_1.id, - endpoint_type=ModelEndpointType.ASYNC, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=1, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage="10G", - min_workers=1, - max_workers=3, - per_worker=2, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def update_model_endpoint_request( - model_bundle_2: ModelBundle, -) -> UpdateModelEndpointV1Request: - return UpdateModelEndpointV1Request( - model_bundle_id=model_bundle_2.id, - metadata={"test_new_key": "test_new_value"}, - cpus=2, - memory="16G", - min_workers=1, - max_workers=4, - per_worker=2, - ) - - -@pytest.fixture -def create_docker_image_batch_job_bundle_request() -> CreateDockerImageBatchJobBundleV1Request: - return CreateDockerImageBatchJobBundleV1Request( - name="name", - image_repository="repo", - image_tag="tag", - command=["sudo", "rn", "-rf"], - env=dict(hi="hi", bye="bye"), - mount_location=None, - resource_requests=CreateDockerImageBatchJobResourceRequests( - cpus=1, memory=None, storage=None, gpus=None, gpu_type=None - ), - ) - - -@pytest.fixture -def create_llm_model_endpoint_request_sync() -> CreateLLMModelEndpointV1Request: - return CreateLLMModelEndpointV1Request( - name="test_llm_endpoint_name_sync", - model_name="mpt-7b", - source="hugging_face", - inference_framework="deepspeed", - inference_framework_image_tag="test_tag", - num_shards=2, - endpoint_type=ModelEndpointType.SYNC, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=2, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, - min_workers=1, - max_workers=3, - per_worker=2, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def create_llm_model_endpoint_request_async() -> CreateLLMModelEndpointV1Request: - return CreateLLMModelEndpointV1Request( - name="test_llm_endpoint_name_async", - model_name="mpt-7b", - source="hugging_face", - inference_framework="deepspeed", - inference_framework_image_tag="test_tag", - num_shards=2, - endpoint_type=ModelEndpointType.ASYNC, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=2, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, - min_workers=0, - max_workers=3, - per_worker=2, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def create_llm_model_endpoint_request_streaming() -> CreateLLMModelEndpointV1Request: - return CreateLLMModelEndpointV1Request( - name="test_llm_endpoint_name_streaming", - model_name="mpt-7b", - source="hugging_face", - inference_framework="deepspeed", - inference_framework_image_tag="test_tag", - num_shards=2, - endpoint_type=ModelEndpointType.STREAMING, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=2, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, - min_workers=1, - max_workers=3, - per_worker=2, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def create_llm_model_endpoint_text_generation_inference_request_streaming() -> ( - CreateLLMModelEndpointV1Request -): - return CreateLLMModelEndpointV1Request( - name="test_llm_endpoint_name_tgi_streaming", - model_name="mpt-7b", - source="hugging_face", - inference_framework="deepspeed", - inference_framework_image_tag="test_tag", - num_shards=2, - quantize=Quantization.BITSANDBYTES, - endpoint_type=ModelEndpointType.STREAMING, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=2, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, - min_workers=1, - max_workers=3, - per_worker=2, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def create_llm_model_endpoint_text_generation_inference_request_async() -> ( - CreateLLMModelEndpointV1Request -): - return CreateLLMModelEndpointV1Request( - name="test_llm_endpoint_name_tgi_async", - model_name="mpt-7b", - source="hugging_face", - inference_framework="text_generation_inference", - inference_framework_image_tag="test_tag", - num_shards=2, - quantize=Quantization.BITSANDBYTES, - endpoint_type=ModelEndpointType.ASYNC, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=2, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, - min_workers=1, - max_workers=3, - per_worker=2, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def create_llm_model_endpoint_request_invalid_model_name() -> CreateLLMModelEndpointV1Request: - return CreateLLMModelEndpointV1Request( - name="test_llm_endpoint_name_1", - model_name="nonexist", - source="hugging_face", - inference_framework="deepspeed", - inference_framework_image_tag="test_tag", - num_shards=2, - endpoint_type=ModelEndpointType.SYNC, - metadata={}, - post_inference_hooks=[], - cpus=1, - gpus=2, - memory="8G", - gpu_type=GpuType.NVIDIA_TESLA_T4, - storage=None, - min_workers=1, - max_workers=3, - per_worker=2, - labels={"team": "infra", "product": "my_product"}, - aws_role="test_aws_role", - results_s3_bucket="test_s3_bucket", - ) - - -@pytest.fixture -def completion_sync_request() -> CompletionSyncV1Request: - return CompletionSyncV1Request( - prompts=["test_prompt_1", "test_prompt_2"], - max_new_tokens=10, - temperature=0.5, - ) - - -@pytest.fixture -def completion_stream_request() -> CompletionStreamV1Request: - return CompletionStreamV1Request( - prompt="test_prompt_1", - max_new_tokens=10, - temperature=0.5, - ) diff --git a/server/tests/unit/domain/test_llm_use_cases.py b/server/tests/unit/domain/test_llm_use_cases.py deleted file mode 100644 index d33c4fb9..00000000 --- a/server/tests/unit/domain/test_llm_use_cases.py +++ /dev/null @@ -1,587 +0,0 @@ -from typing import Any, Tuple - -import pytest -from llm_engine_server.common.dtos.llms import ( - CompletionOutput, - CompletionStreamV1Request, - CompletionSyncV1Request, - CreateLLMModelEndpointV1Request, - CreateLLMModelEndpointV1Response, -) -from llm_engine_server.common.dtos.tasks import SyncEndpointPredictV1Response, TaskStatus -from llm_engine_server.core.auth.authentication_repository import User -from llm_engine_server.core.domain_exceptions import ( - ObjectHasInvalidValueException, - ObjectNotAuthorizedException, - ObjectNotFoundException, -) -from llm_engine_server.domain.entities import ModelEndpoint, ModelEndpointType -from llm_engine_server.domain.exceptions import EndpointUnsupportedInferenceTypeException -from llm_engine_server.domain.use_cases.llm_model_endpoint_use_cases import ( - CompletionStreamV1UseCase, - CompletionSyncV1UseCase, - CreateLLMModelEndpointV1UseCase, - GetLLMModelEndpointByNameV1UseCase, -) -from llm_engine_server.domain.use_cases.model_bundle_use_cases import CreateModelBundleV2UseCase - - -@pytest.mark.asyncio -async def test_create_model_endpoint_use_case_success( - test_api_key: str, - fake_model_bundle_repository, - fake_model_endpoint_service, - fake_docker_repository_image_always_exists, - fake_model_primitive_gateway, - create_llm_model_endpoint_request_async: CreateLLMModelEndpointV1Request, - create_llm_model_endpoint_request_sync: CreateLLMModelEndpointV1Request, - create_llm_model_endpoint_request_streaming: CreateLLMModelEndpointV1Request, -): - fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository - bundle_use_case = CreateModelBundleV2UseCase( - model_bundle_repository=fake_model_bundle_repository, - docker_repository=fake_docker_repository_image_always_exists, - model_primitive_gateway=fake_model_primitive_gateway, - ) - use_case = CreateLLMModelEndpointV1UseCase( - create_model_bundle_use_case=bundle_use_case, - model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_async) - assert response_1.endpoint_creation_task_id - assert isinstance(response_1, CreateLLMModelEndpointV1Response) - endpoint = ( - await fake_model_endpoint_service.list_model_endpoints( - owner=None, - name=create_llm_model_endpoint_request_async.name, - order_by=None, - ) - )[0] - assert endpoint.record.endpoint_type == ModelEndpointType.ASYNC - assert endpoint.record.metadata == { - "_llm": { - "model_name": create_llm_model_endpoint_request_async.model_name, - "source": create_llm_model_endpoint_request_async.source, - "inference_framework": create_llm_model_endpoint_request_async.inference_framework, - "inference_framework_image_tag": create_llm_model_endpoint_request_async.inference_framework_image_tag, - "num_shards": create_llm_model_endpoint_request_async.num_shards, - "quantize": None, - } - } - - response_2 = await use_case.execute(user=user, request=create_llm_model_endpoint_request_sync) - assert response_2.endpoint_creation_task_id - assert isinstance(response_2, CreateLLMModelEndpointV1Response) - endpoint = ( - await fake_model_endpoint_service.list_model_endpoints( - owner=None, - name=create_llm_model_endpoint_request_sync.name, - order_by=None, - ) - )[0] - assert endpoint.record.endpoint_type == ModelEndpointType.SYNC - assert endpoint.record.metadata == { - "_llm": { - "model_name": create_llm_model_endpoint_request_sync.model_name, - "source": create_llm_model_endpoint_request_sync.source, - "inference_framework": create_llm_model_endpoint_request_sync.inference_framework, - "inference_framework_image_tag": create_llm_model_endpoint_request_sync.inference_framework_image_tag, - "num_shards": create_llm_model_endpoint_request_sync.num_shards, - "quantize": None, - } - } - - response_3 = await use_case.execute( - user=user, request=create_llm_model_endpoint_request_streaming - ) - assert response_3.endpoint_creation_task_id - assert isinstance(response_3, CreateLLMModelEndpointV1Response) - endpoint = ( - await fake_model_endpoint_service.list_model_endpoints( - owner=None, - name=create_llm_model_endpoint_request_streaming.name, - order_by=None, - ) - )[0] - assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING - assert endpoint.record.metadata == { - "_llm": { - "model_name": create_llm_model_endpoint_request_streaming.model_name, - "source": create_llm_model_endpoint_request_streaming.source, - "inference_framework": create_llm_model_endpoint_request_streaming.inference_framework, - "inference_framework_image_tag": create_llm_model_endpoint_request_streaming.inference_framework_image_tag, - "num_shards": create_llm_model_endpoint_request_streaming.num_shards, - "quantize": None, - } - } - - -@pytest.mark.asyncio -async def test_create_model_endpoint_text_generation_inference_use_case_success( - test_api_key: str, - fake_model_bundle_repository, - fake_model_endpoint_service, - fake_docker_repository_image_always_exists, - fake_model_primitive_gateway, - create_llm_model_endpoint_text_generation_inference_request_async: CreateLLMModelEndpointV1Request, - create_llm_model_endpoint_text_generation_inference_request_streaming: CreateLLMModelEndpointV1Request, -): - fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository - bundle_use_case = CreateModelBundleV2UseCase( - model_bundle_repository=fake_model_bundle_repository, - docker_repository=fake_docker_repository_image_always_exists, - model_primitive_gateway=fake_model_primitive_gateway, - ) - use_case = CreateLLMModelEndpointV1UseCase( - create_model_bundle_use_case=bundle_use_case, - model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = await use_case.execute( - user=user, request=create_llm_model_endpoint_text_generation_inference_request_streaming - ) - assert response_1.endpoint_creation_task_id - assert isinstance(response_1, CreateLLMModelEndpointV1Response) - endpoint = ( - await fake_model_endpoint_service.list_model_endpoints( - owner=None, - name=create_llm_model_endpoint_text_generation_inference_request_streaming.name, - order_by=None, - ) - )[0] - assert endpoint.record.endpoint_type == ModelEndpointType.STREAMING - assert endpoint.record.metadata == { - "_llm": { - "model_name": create_llm_model_endpoint_text_generation_inference_request_streaming.model_name, - "source": create_llm_model_endpoint_text_generation_inference_request_streaming.source, - "inference_framework": create_llm_model_endpoint_text_generation_inference_request_streaming.inference_framework, - "inference_framework_image_tag": create_llm_model_endpoint_text_generation_inference_request_streaming.inference_framework_image_tag, - "num_shards": create_llm_model_endpoint_text_generation_inference_request_streaming.num_shards, - "quantize": create_llm_model_endpoint_text_generation_inference_request_streaming.quantize, - } - } - - with pytest.raises(ObjectHasInvalidValueException): - await use_case.execute( - user=user, request=create_llm_model_endpoint_text_generation_inference_request_async - ) - - -@pytest.mark.asyncio -async def test_create_llm_model_endpoint_use_case_raises_invalid_value_exception( - test_api_key: str, - fake_model_bundle_repository, - fake_model_endpoint_service, - fake_docker_repository_image_always_exists, - fake_model_primitive_gateway, - create_llm_model_endpoint_request_invalid_model_name: CreateLLMModelEndpointV1Request, -): - fake_model_endpoint_service.model_bundle_repository = fake_model_bundle_repository - bundle_use_case = CreateModelBundleV2UseCase( - model_bundle_repository=fake_model_bundle_repository, - docker_repository=fake_docker_repository_image_always_exists, - model_primitive_gateway=fake_model_primitive_gateway, - ) - use_case = CreateLLMModelEndpointV1UseCase( - create_model_bundle_use_case=bundle_use_case, - model_bundle_repository=fake_model_bundle_repository, - model_endpoint_service=fake_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - with pytest.raises(ObjectHasInvalidValueException): - await use_case.execute( - user=user, request=create_llm_model_endpoint_request_invalid_model_name - ) - - -@pytest.mark.asyncio -async def test_get_llm_model_endpoint_use_case_raises_not_found( - test_api_key: str, - fake_llm_model_endpoint_service, - llm_model_endpoint_async: Tuple[ModelEndpoint, Any], -): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_async[0]) - use_case = GetLLMModelEndpointByNameV1UseCase( - llm_model_endpoint_service=fake_llm_model_endpoint_service - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - with pytest.raises(ObjectNotFoundException): - await use_case.execute(user=user, model_endpoint_name="invalid_model_endpoint_name") - - -@pytest.mark.asyncio -async def test_get_llm_model_endpoint_use_case_raises_not_authorized( - test_api_key: str, - fake_llm_model_endpoint_service, - llm_model_endpoint_async: Tuple[ModelEndpoint, Any], -): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_async[0]) - use_case = GetLLMModelEndpointByNameV1UseCase( - llm_model_endpoint_service=fake_llm_model_endpoint_service - ) - llm_model_endpoint_async[0].record.public_inference = False - user = User(user_id="non_exist", team_id="non_exist", is_privileged_user=False) - with pytest.raises(ObjectNotAuthorizedException): - await use_case.execute( - user=user, model_endpoint_name=llm_model_endpoint_async[0].record.name - ) - - -@pytest.mark.asyncio -async def test_completion_sync_use_case_success( - test_api_key: str, - fake_model_endpoint_service, - fake_llm_model_endpoint_service, - llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], - completion_sync_request: CompletionSyncV1Request, -): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) - fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={ - "result": [ - { - "error": None, - "text": "I am a newbie to the world of programming.", - "token_probs": { - "tokens": [ - "I", - " am", - " a", - " new", - "bie", - " to", - " the", - " world", - " of", - " programming", - ".", - ] - }, - "tokens_consumed": 25, - } - ] - }, - traceback=None, - ) - ) - use_case = CompletionSyncV1UseCase( - model_endpoint_service=fake_model_endpoint_service, - llm_model_endpoint_service=fake_llm_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = await use_case.execute( - user=user, - model_endpoint_name=llm_model_endpoint_sync[0].record.name, - request=completion_sync_request, - ) - assert response_1.status == TaskStatus.SUCCESS - assert response_1.outputs == [ - CompletionOutput( - text="I am a newbie to the world of programming.", - num_completion_tokens=11, - ) - ] - - -@pytest.mark.asyncio -async def test_completion_sync_text_generation_inference_use_case_success( - test_api_key: str, - fake_model_endpoint_service, - fake_llm_model_endpoint_service, - llm_model_endpoint_text_generation_inference: ModelEndpoint, - completion_sync_request: CompletionSyncV1Request, -): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) - fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={ - "result": """ - { - "generated_text": " Deep Learning is a new type of machine learning", - "details": { - "finish_reason": "length", - "generated_tokens": 9, - "prefill": [ - { - "id": 10560, - "text": "What" - }, - { - "id": 632, - "text": " is" - }, - { - "id": 89554, - "text": " Deep" - }, - { - "id": 89950, - "text": " Learning" - }, - { - "id": 34, - "text": "?" - } - ], - "tokens": [ - { - "text": " Deep" - }, - { - "text": " Learning" - }, - { - "text": " is" - }, - { - "text": " a" - }, - { - "text": " new" - }, - { - "text": " type" - }, - { - "text": " of" - }, - { - "text": " machine" - }, - { - "text": " learning" - } - ] - } - } -""" - }, - traceback=None, - ) - use_case = CompletionSyncV1UseCase( - model_endpoint_service=fake_model_endpoint_service, - llm_model_endpoint_service=fake_llm_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = await use_case.execute( - user=user, - model_endpoint_name=llm_model_endpoint_text_generation_inference.record.name, - request=completion_sync_request, - ) - assert response_1.status == TaskStatus.SUCCESS - print(response_1.outputs) - assert response_1.outputs == [ - CompletionOutput( - text=" Deep Learning is a new type of machine learning", - num_completion_tokens=9, - ), - CompletionOutput( - text=" Deep Learning is a new type of machine learning", - num_completion_tokens=9, - ), - ] - - -@pytest.mark.asyncio -async def test_completion_sync_use_case_predict_failed( - test_api_key: str, - fake_model_endpoint_service, - fake_llm_model_endpoint_service, - llm_model_endpoint_sync: Tuple[ModelEndpoint, Any], - completion_sync_request: CompletionSyncV1Request, -): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_sync[0]) - fake_model_endpoint_service.sync_model_endpoint_inference_gateway.response = ( - SyncEndpointPredictV1Response( - status=TaskStatus.FAILURE, - result=None, - traceback="failed to predict", - ) - ) - use_case = CompletionSyncV1UseCase( - model_endpoint_service=fake_model_endpoint_service, - llm_model_endpoint_service=fake_llm_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = await use_case.execute( - user=user, - model_endpoint_name=llm_model_endpoint_sync[0].record.name, - request=completion_sync_request, - ) - assert response_1.status == TaskStatus.FAILURE - assert len(response_1.outputs) == 0 - assert response_1.traceback == "failed to predict" - - -@pytest.mark.asyncio -async def test_completion_sync_use_case_not_sync_endpoint_raises( - test_api_key: str, - fake_model_endpoint_service, - fake_llm_model_endpoint_service, - llm_model_endpoint_async: Tuple[ModelEndpoint, Any], - completion_sync_request: CompletionSyncV1Request, -): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_async[0]) - use_case = CompletionSyncV1UseCase( - model_endpoint_service=fake_model_endpoint_service, - llm_model_endpoint_service=fake_llm_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - with pytest.raises(EndpointUnsupportedInferenceTypeException): - await use_case.execute( - user=user, - model_endpoint_name=llm_model_endpoint_async[0].record.name, - request=completion_sync_request, - ) - - -@pytest.mark.asyncio -async def test_completion_stream_use_case_success( - test_api_key: str, - fake_model_endpoint_service, - fake_llm_model_endpoint_service, - llm_model_endpoint_streaming: ModelEndpoint, - completion_stream_request: CompletionStreamV1Request, -): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_streaming) - fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": "I"}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": " am"}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": " a"}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": " new"}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": "bie"}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": "."}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={ - "result": { - "response": [ - { - "error": None, - "text": "I am a newbie.", - "token_probs": { - "tokens": [ - "I", - " am", - " a", - " new", - "bie", - ".", - ] - }, - "tokens_consumed": 25, - } - ] - } - }, - traceback=None, - ), - ] - use_case = CompletionStreamV1UseCase( - model_endpoint_service=fake_model_endpoint_service, - llm_model_endpoint_service=fake_llm_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = use_case.execute( - user=user, - model_endpoint_name=llm_model_endpoint_streaming.record.name, - request=completion_stream_request, - ) - output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] - i = 0 - async for message in response_1: - assert message.dict()["status"] == "SUCCESS" - assert message.dict()["output"]["text"] == output_texts[i] - if i == 6: - assert message.dict()["output"]["num_completion_tokens"] == 6 - i += 1 - - -@pytest.mark.asyncio -async def test_completion_stream_text_generation_inference_use_case_success( - test_api_key: str, - fake_model_endpoint_service, - fake_llm_model_endpoint_service, - llm_model_endpoint_text_generation_inference: ModelEndpoint, - completion_stream_request: CompletionStreamV1Request, -): - fake_llm_model_endpoint_service.add_model_endpoint(llm_model_endpoint_text_generation_inference) - fake_model_endpoint_service.streaming_model_endpoint_inference_gateway.responses = [ - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": {"text": "I"}}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": {"text": " am"}}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": {"text": " a"}}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": {"text": " new"}}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": {"text": "bie"}}}, - traceback=None, - ), - SyncEndpointPredictV1Response( - status=TaskStatus.SUCCESS, - result={"result": {"token": {"text": "."}, "generated_text": "I am a newbie."}}, - traceback=None, - ), - ] - use_case = CompletionStreamV1UseCase( - model_endpoint_service=fake_model_endpoint_service, - llm_model_endpoint_service=fake_llm_model_endpoint_service, - ) - user = User(user_id=test_api_key, team_id=test_api_key, is_privileged_user=True) - response_1 = use_case.execute( - user=user, - model_endpoint_name=llm_model_endpoint_text_generation_inference.record.name, - request=completion_stream_request, - ) - output_texts = ["I", " am", " a", " new", "bie", ".", "I am a newbie."] - i = 0 - async for message in response_1: - assert message.dict()["status"] == "SUCCESS" - assert message.dict()["output"]["text"] == output_texts[i] - if i == 5: - assert message.dict()["output"]["num_completion_tokens"] == 6 - i += 1 diff --git a/server/tests/unit/inference/test_forwarding.py b/server/tests/unit/inference/test_forwarding.py deleted file mode 100644 index 5f343fcd..00000000 --- a/server/tests/unit/inference/test_forwarding.py +++ /dev/null @@ -1,271 +0,0 @@ -import json -from dataclasses import dataclass -from typing import Mapping -from unittest import mock - -import pytest -from llm_engine_server.core.utils.env import environment -from llm_engine_server.domain.entities import ModelEndpointConfig -from llm_engine_server.inference.forwarding.forwarding import ( - ENV_SERIALIZE_RESULTS_AS_STRING, - KEY_SERIALIZE_RESULTS_AS_STRING, - Forwarder, - LoadForwarder, - LoadStreamingForwarder, - StreamingForwarder, -) -from llm_engine_server.inference.infra.gateways.datadog_inference_monitoring_metrics_gateway import ( - DatadogInferenceMonitoringMetricsGateway, -) -from llm_engine_server.inference.post_inference_hooks import PostInferenceHooksHandler - -PAYLOAD: Mapping[str, Mapping[str, str]] = {"hello": "world"} - - -def mocked_get(*args, **kwargs): # noqa - @dataclass - class mocked_static_status_code: - status_code: int = 200 - - return mocked_static_status_code() - - -def mocked_post(*args, **kwargs): # noqa - @dataclass - class mocked_static_json: - def json(self) -> dict: - return PAYLOAD # type: ignore - - return mocked_static_json() - - -def mocked_sse_client(*args, **kwargs): # noqa - @dataclass - class Event: - data: str - - @dataclass - class mocked_static_events: - def events(self) -> list: - payload_json = json.dumps(PAYLOAD) - return [Event(data=payload_json), Event(data=payload_json)] - - return mocked_static_events() - - -def mocked_get_endpoint_config(): - return ModelEndpointConfig( - endpoint_name="test_endpoint_name", - bundle_name="test_bundle_name", - ) - - -@pytest.fixture -def post_inference_hooks_handler(): - handler = PostInferenceHooksHandler( - endpoint_name="test_endpoint_name", - bundle_name="test_bundle_name", - post_inference_hooks=[], - user_id="test_user_id", - default_callback_url=None, - default_callback_auth=None, - monitoring_metrics_gateway=DatadogInferenceMonitoringMetricsGateway(), - ) - return handler - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -def test_forwarders(post_inference_hooks_handler): - fwd = Forwarder( - "ignored", - llm_engine_unwrap=True, - serialize_results_as_string=False, - post_inference_hooks_handler=post_inference_hooks_handler, - wrap_response=True, - ) - json_response = fwd({"args": {"ignore": "me"}}) - _check(json_response) - - -def _check(json_response) -> None: - assert json_response == {"result": PAYLOAD} - - -def _check_responses_not_wrapped(json_response) -> None: - assert json_response == PAYLOAD - - -def _check_streaming(streaming_response) -> None: - streaming_response_list = list(streaming_response) - assert len(streaming_response_list) == 2 - assert streaming_response_list[0] == {"result": PAYLOAD} - assert streaming_response_list[1] == {"result": PAYLOAD} - - -def _check_streaming_serialized(streaming_response) -> None: - streaming_response_list = list(streaming_response) - assert len(streaming_response_list) == 2 - assert streaming_response_list[0] == {"result": json.dumps(PAYLOAD)} - assert streaming_response_list[1] == {"result": json.dumps(PAYLOAD)} - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -def test_forwarders_serialize_results_as_string(post_inference_hooks_handler): - fwd = Forwarder( - "ignored", - llm_engine_unwrap=True, - serialize_results_as_string=True, - post_inference_hooks_handler=post_inference_hooks_handler, - wrap_response=True, - ) - json_response = fwd({"args": {"ignore": "me"}}) - _check_serialized(json_response) - - -def _check_serialized(json_response) -> None: - assert isinstance(json_response["result"], str) - assert len(json_response) == 1, f"expecting only 'result' key, but got {json_response=}" - assert json.loads(json_response["result"]) == PAYLOAD - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -def test_forwarders_override_serialize_results(post_inference_hooks_handler): - fwd = Forwarder( - "ignored", - llm_engine_unwrap=True, - serialize_results_as_string=True, - post_inference_hooks_handler=post_inference_hooks_handler, - wrap_response=True, - ) - json_response = fwd({"args": {"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: False}}) - _check(json_response) - assert json_response == {"result": PAYLOAD} - - fwd = Forwarder( - "ignored", - llm_engine_unwrap=True, - serialize_results_as_string=False, - post_inference_hooks_handler=post_inference_hooks_handler, - wrap_response=True, - ) - json_response = fwd({"args": {"ignore": "me", KEY_SERIALIZE_RESULTS_AS_STRING: True}}) - _check_serialized(json_response) - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -def test_forwarder_does_not_wrap_response(post_inference_hooks_handler): - fwd = Forwarder( - "ignored", - llm_engine_unwrap=True, - serialize_results_as_string=False, - post_inference_hooks_handler=post_inference_hooks_handler, - wrap_response=False, - ) - json_response = fwd({"args": {"ignore": "me"}}) - _check_responses_not_wrapped(json_response) - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -@mock.patch( - "llm_engine_server.inference.forwarding.forwarding.get_endpoint_config", - mocked_get_endpoint_config, -) -def test_forwarder_loader(): - fwd = LoadForwarder(serialize_results_as_string=True).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) - _check_serialized(json_response) - - fwd = LoadForwarder(serialize_results_as_string=False).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) - _check(json_response) - - fwd = LoadForwarder(wrap_response=False).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) - _check_responses_not_wrapped(json_response) - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -@mock.patch( - "llm_engine_server.inference.forwarding.forwarding.get_endpoint_config", - mocked_get_endpoint_config, -) -def test_forwarder_loader_env_serialize_behavior(post_inference_hooks_handler): - with environment(**{ENV_SERIALIZE_RESULTS_AS_STRING: "false"}): - fwd = LoadForwarder(serialize_results_as_string=True).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) - _check(json_response) - - with environment(**{ENV_SERIALIZE_RESULTS_AS_STRING: "true"}): - fwd = LoadForwarder(serialize_results_as_string=False).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) - _check_serialized(json_response) - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -def test_forwarder_serialize_within_args(post_inference_hooks_handler): - # standard Spellbook-Serve-created forwarder - fwd = Forwarder( - "ignored", - llm_engine_unwrap=True, - serialize_results_as_string=True, - post_inference_hooks_handler=post_inference_hooks_handler, - wrap_response=True, - ) - # expected: no `serialize_results_as_string` at top-level nor in 'args' - json_response = fwd({"something": "to ignore", "args": {"my": "payload", "is": "here"}}) - _check_serialized(json_response) - # unwraps under "args" to find `serialize_results_as_string` - payload = { - "something": "to ignore", - "args": {"my": "payload", "is": "here", "serialize_results_as_string": False}, - } - json_response = fwd(payload) - _check(json_response) - # w/o unwrapping it won't "find" the `"serialize_results_as_string": False` directive - fwd = Forwarder( - "ignored", - llm_engine_unwrap=False, - serialize_results_as_string=True, - post_inference_hooks_handler=post_inference_hooks_handler, - wrap_response=True, - ) - json_response = fwd(payload) - _check_serialized(json_response) - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -@mock.patch("sseclient.SSEClient", mocked_sse_client) -def test_streaming_forwarders(post_inference_hooks_handler): - fwd = StreamingForwarder( - "ignored", - llm_engine_unwrap=True, - serialize_results_as_string=False, - post_inference_hooks_handler=post_inference_hooks_handler, - ) - response = fwd({"args": {"ignore": "me"}}) - _check_streaming(response) - - -@mock.patch("requests.post", mocked_post) -@mock.patch("requests.get", mocked_get) -@mock.patch("sseclient.SSEClient", mocked_sse_client) -@mock.patch( - "llm_engine_server.inference.forwarding.forwarding.get_endpoint_config", - mocked_get_endpoint_config, -) -def test_streaming_forwarder_loader(): - fwd = LoadStreamingForwarder(serialize_results_as_string=True).load(None, None) # type: ignore - json_response = fwd({"args": {"ignore": "me"}}) - _check_streaming_serialized(json_response) - - fwd = LoadStreamingForwarder(serialize_results_as_string=False).load(None, None) # type: ignore - response = fwd({"args": {"ignore": "me"}}) - _check_streaming(response) diff --git a/server/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py b/server/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py deleted file mode 100644 index e1652ceb..00000000 --- a/server/tests/unit/infra/gateways/test_live_docker_image_batch_job_gateway.py +++ /dev/null @@ -1,44 +0,0 @@ -from llm_engine_server.infra.gateways.live_docker_image_batch_job_gateway import ( - K8sEnvDict, - _add_list_values, - _check_batch_job_id_valid, - _get_job_id, -) - - -def test_valid_job_ids_are_valid(): - for _ in range(20): - # _get_job_id() is nondeterministic - job_id = _get_job_id() - assert _check_batch_job_id_valid(job_id), f"job_id {job_id} apparently isn't valid" - - -def test_invalid_job_ids_are_invalid(): - assert not _check_batch_job_id_valid("spaces fail") - assert not _check_batch_job_id_valid("punctuation'") - assert not _check_batch_job_id_valid(".") - - -# test the adding list values -def test_add_list_values(): - default_values = [ - K8sEnvDict(name="default1", value="val1"), - K8sEnvDict(name="default2", value="val2"), - K8sEnvDict(name="default3", value="val3"), - ] - override_values = [ - K8sEnvDict(name="default1", value="override0"), - K8sEnvDict(name="override1", value="override1"), - K8sEnvDict(name="override2", value="override2"), - ] - expected_values = [ - K8sEnvDict(name="default1", value="val1"), - K8sEnvDict(name="default2", value="val2"), - K8sEnvDict(name="default3", value="val3"), - K8sEnvDict(name="override1", value="override1"), - K8sEnvDict(name="override2", value="override2"), - ] - - actual_values = _add_list_values(default_values, override_values) - actual_values.sort(key=lambda x: x["name"]) - assert expected_values == actual_values diff --git a/server/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py b/server/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py deleted file mode 100644 index 58a735ee..00000000 --- a/server/tests/unit/infra/gateways/test_live_streaming_model_endpoint_inference_gateway.py +++ /dev/null @@ -1,185 +0,0 @@ -import json -from dataclasses import dataclass -from typing import Any, Dict, Tuple -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from llm_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, - SyncEndpointPredictV1Response, -) -from llm_engine_server.domain.exceptions import UpstreamServiceError -from llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway import ( - LiveStreamingModelEndpointInferenceGateway, -) - - -@dataclass -class FakeIterator: - content: bytes = b'{"test": "content"}' - count: int = 0 - - def __aiter__(self): - return self - - async def __anext__(self): - self.count = self.count + 1 - if self.count == 1: - return b"data: " + self.content - if self.count in {2, 3}: - return b"\n" - if self.count == 4: - raise StopAsyncIteration - - -@dataclass -class FakeResponse: - def __init__(self, status: int, message_content: bytes = b'{"test": "content"}'): - self.status = status - self.message_content = message_content - self.content = FakeIterator(content=message_content) - - async def read(self): - return self.message_content - - -def _get_mock_client_session(fake_response: FakeResponse): - mock_post = AsyncMock(return_value=fake_response) - mock_client_session_val = AsyncMock() - mock_client_session_val.post = mock_post - mock_client_session_val.__aenter__ = AsyncMock(return_value=mock_client_session_val) - mock_client_session_val.__aexit__ = AsyncMock() - mock_client_session = MagicMock(return_value=mock_client_session_val) - return mock_client_session - - -@pytest.mark.asyncio -async def test_make_request_with_retries_success(): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) - - fake_response = FakeResponse(status=200) - mock_client_session = _get_mock_client_session(fake_response) - - with patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - response = gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) - count = 0 - async for message in response: - assert message == {"test": "content"} - count += 1 - assert count == 1 - - -@pytest.mark.asyncio -async def test_make_request_with_retries_failed_429(): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) - - fake_response = FakeResponse(status=429) - mock_client_session = _get_mock_client_session(fake_response) - - with pytest.raises(UpstreamServiceError), patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - async for response in gateway.make_request_with_retries("test_request_url", {}, 0.05, 2): - response - - -@pytest.mark.asyncio -async def test_make_request_with_retries_failed_traceback(): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) - - fake_response = FakeResponse(status=500) - mock_client_session = _get_mock_client_session(fake_response) - - with pytest.raises(UpstreamServiceError), patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - async for response in gateway.make_request_with_retries("test_request_url", {}, 0.05, 2): - response - - -@pytest.mark.asyncio -async def test_streaming_predict_success( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] -): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) - - fake_response = FakeResponse(status=200) - mock_client_session = _get_mock_client_session(fake_response) - with patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - response = gateway.streaming_predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] - ) - count = 0 - async for message in response: - assert isinstance(message, SyncEndpointPredictV1Response) - assert message.dict() == { - "status": "SUCCESS", - "result": {"test": "content"}, - "traceback": None, - } - count += 1 - assert count == 1 - - -@pytest.mark.asyncio -async def test_predict_raises_traceback_json( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] -): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) - - content = json.dumps({"detail": {"traceback": "test_traceback"}}).encode("utf-8") - fake_response = FakeResponse(status=500, message_content=content) - mock_client_session = _get_mock_client_session(fake_response) - with patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - response = gateway.streaming_predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] - ) - count = 0 - async for message in response: - assert isinstance(message, SyncEndpointPredictV1Response) - assert message.dict() == { - "status": "FAILURE", - "result": None, - "traceback": "test_traceback", - } - count += 1 - assert count == 1 - - -@pytest.mark.asyncio -async def test_predict_raises_traceback_not_json( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] -): - gateway = LiveStreamingModelEndpointInferenceGateway(use_asyncio=True) - - content = b"Test traceback content" - fake_response = FakeResponse(status=500, message_content=content) - mock_client_session = _get_mock_client_session(fake_response) - with patch( - "llm_engine_server.infra.gateways.live_streaming_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - response = gateway.streaming_predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] - ) - count = 0 - async for message in response: - assert isinstance(message, SyncEndpointPredictV1Response) - assert message.dict() == { - "status": "FAILURE", - "result": None, - "traceback": "Test traceback content", - } - count += 1 - assert count == 1 diff --git a/server/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py b/server/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py deleted file mode 100644 index d89ae7b7..00000000 --- a/server/tests/unit/infra/gateways/test_live_sync_model_endpoint_inference_gateway.py +++ /dev/null @@ -1,151 +0,0 @@ -import json -from dataclasses import dataclass -from typing import Any, Dict, Tuple -from unittest.mock import AsyncMock, MagicMock, patch - -import pytest -from llm_engine_server.common.dtos.tasks import ( - EndpointPredictV1Request, - SyncEndpointPredictV1Response, -) -from llm_engine_server.domain.exceptions import UpstreamServiceError -from llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway import ( - LiveSyncModelEndpointInferenceGateway, -) - - -@dataclass -class FakeResponse: - status: int - content: bytes = b"test_content" - body: Any = None - - async def read(self): - return self.content - - async def json(self): - return self.body if self.body else {"test_key": "test_value"} - - -def _get_mock_client_session(fake_response: FakeResponse): - mock_post = AsyncMock(return_value=fake_response) - mock_client_session_val = AsyncMock() - mock_client_session_val.post = mock_post - mock_client_session_val.__aenter__ = AsyncMock(return_value=mock_client_session_val) - mock_client_session_val.__aexit__ = AsyncMock() - mock_client_session = MagicMock(return_value=mock_client_session_val) - return mock_client_session - - -@pytest.mark.asyncio -async def test_make_request_with_retries_success(): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) - - fake_response = FakeResponse(status=200) - mock_client_session = _get_mock_client_session(fake_response) - - with patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - response = await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) - assert response == {"test_key": "test_value"} - - -@pytest.mark.asyncio -async def test_make_request_with_retries_failed_429(): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) - - fake_response = FakeResponse(status=429) - mock_client_session = _get_mock_client_session(fake_response) - - with pytest.raises(UpstreamServiceError), patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) - - -@pytest.mark.asyncio -async def test_make_request_with_retries_failed_traceback(): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) - - fake_response = FakeResponse(status=500) - mock_client_session = _get_mock_client_session(fake_response) - - with pytest.raises(UpstreamServiceError), patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - await gateway.make_request_with_retries("test_request_url", {}, 0.05, 2) - - -@pytest.mark.asyncio -async def test_predict_success( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] -): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) - - fake_response = FakeResponse(status=200, body={"test_key": "test_value"}) - mock_client_session = _get_mock_client_session(fake_response) - with patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - response = await gateway.predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] - ) - assert isinstance(response, SyncEndpointPredictV1Response) - assert response.dict() == { - "status": "SUCCESS", - "result": {"test_key": "test_value"}, - "traceback": None, - } - - -@pytest.mark.asyncio -async def test_predict_raises_traceback_json( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] -): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) - - content = json.dumps({"detail": {"traceback": "test_traceback"}}).encode("utf-8") - fake_response = FakeResponse(status=500, content=content) - mock_client_session = _get_mock_client_session(fake_response) - with patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - response = await gateway.predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] - ) - assert isinstance(response, SyncEndpointPredictV1Response) - assert response.dict() == { - "status": "FAILURE", - "result": None, - "traceback": "test_traceback", - } - - -@pytest.mark.asyncio -async def test_predict_raises_traceback_not_json( - endpoint_predict_request_1: Tuple[EndpointPredictV1Request, Dict[str, Any]] -): - gateway = LiveSyncModelEndpointInferenceGateway(use_asyncio=True) - - content = b"Test traceback content" - fake_response = FakeResponse(status=500, content=content) - mock_client_session = _get_mock_client_session(fake_response) - with patch( - "llm_engine_server.infra.gateways.live_sync_model_endpoint_inference_gateway.aiohttp.ClientSession", - mock_client_session, - ): - response = await gateway.predict( - topic="test_topic", predict_request=endpoint_predict_request_1[0] - ) - assert isinstance(response, SyncEndpointPredictV1Response) - assert response.dict() == { - "status": "FAILURE", - "result": None, - "traceback": "Test traceback content", - } diff --git a/server/tests/unit/infra/services/test_image_cache_service.py b/server/tests/unit/infra/services/test_image_cache_service.py deleted file mode 100644 index 3e04a89f..00000000 --- a/server/tests/unit/infra/services/test_image_cache_service.py +++ /dev/null @@ -1,38 +0,0 @@ -from typing import Any - -import pytest -from llm_engine_server.infra.services.image_cache_service import ImageCacheService - - -@pytest.mark.asyncio -async def test_image_cache_success( - fake_image_cache_service: ImageCacheService, - model_endpoint_1, - model_endpoint_2, - model_endpoint_3, - model_endpoint_4, -): - infra_states = { - model_endpoint_1.record.id: (bool, model_endpoint_1.infra_state), - model_endpoint_2.record.id: (bool, model_endpoint_2.infra_state), - model_endpoint_3.record.id: (bool, model_endpoint_3.infra_state), - model_endpoint_4.record.id: (bool, model_endpoint_4.infra_state), - } - repo: Any = fake_image_cache_service.model_endpoint_record_repository - repo.add_model_endpoint_record(model_endpoint_1.record) - repo.add_model_endpoint_record(model_endpoint_2.record) - repo.add_model_endpoint_record(model_endpoint_3.record) - repo.add_model_endpoint_record(model_endpoint_4.record) - - await fake_image_cache_service.execute(infra_states) # type: ignore - gateway: Any = fake_image_cache_service.image_cache_gateway - assert gateway.cached_images == { - "a10": [], - "a100": [], - "cpu": [], - "t4": [ - "000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:40d3b5fb06d1a8c3d14903390a3b23ae388bdb19", - "000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:e4ea48ddccfb9ca3ef6d846ae9b2d146d7e30b0f", - "000000000000.dkr.ecr.us-west-2.amazonaws.com/catalog-gpu:9a319cd9b897f02291f3242b1395f2b669993cdf-fd", - ], - }