Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support run tokens #2105

Draft
wants to merge 13 commits into
base: main
Choose a base branch
from
30 changes: 24 additions & 6 deletions .github/workflows/ci.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,9 @@ jobs:
- test-go
- test-python
- test-integration
if: startsWith(github.ref, 'refs/tags/')
permissions:
contents: write
contents: read
id-token: write
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v4
Expand All @@ -169,9 +169,27 @@ jobs:
- uses: actions/setup-go@v5
with:
go-version-file: go.mod
- uses: goreleaser/goreleaser-action@v6
- id: build
uses: goreleaser/goreleaser-action@v6
with:
version: '~> v2'
args: release --clean
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
args: build --clean --snapshot --id cog
- name: Authenticate to Google Cloud
uses: google-github-actions/auth@v2
with:
workload_identity_provider: 'projects/1025538909507/locations/global/workloadIdentityPools/github/providers/github-actions'
service_account: 'pipelines-beta-publish@replicate-production.iam.gserviceaccount.com'
- name: Upload release artifacts
uses: google-github-actions/upload-cloud-storage@v2
with:
path: dist/go
destination: replicate-pipelines-beta/releases/${{ fromJSON(steps.build.outputs.metadata).version }}
parent: false
predefinedAcl: publicRead
- name: Upload release artifacts (latest)
uses: google-github-actions/upload-cloud-storage@v2
with:
path: dist/go
destination: replicate-pipelines-beta/releases/latest
parent: false
predefinedAcl: publicRead
64 changes: 54 additions & 10 deletions pkg/cli/predict.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,9 @@ import (
"encoding/json"
"errors"
"fmt"
"io"
"net/http"
"net/url"
"os"
"os/signal"
"path/filepath"
Expand Down Expand Up @@ -341,24 +344,65 @@ func writeOutput(outputPath string, output []byte) error {
}

// writeDataURLOutput writes the payload referenced by outputString to
// outputPath. In the lines visible here, outputString is either an http(s) URL
// (detected by getHTTPURL and fetched with http.Get) or a data URL (decoded with
// dataurl.DecodeString). When addExtension is true, the extension that
// mime.ExtensionByType reports for the content type is appended to outputPath.
//
// NOTE(review): this span appears to interleave the removed and the added lines
// of a diff without +/- markers (the data-URL decode appears twice and `output`
// is declared twice) -- confirm against the real file before changing code.
func writeDataURLOutput(outputString string, outputPath string, addExtension bool) error {
dataurlObj, err := dataurl.DecodeString(outputString)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
var output []byte
var contentType string

if httpURL, ok := getHTTPURL(outputString); ok {
// NOTE(review): resp.StatusCode is never checked below, so the body of a
// non-2xx response would be written out as if it were the model output.
// NOTE(review): http.Get goes through http.DefaultClient, which by default
// has no timeout -- confirm that is acceptable for a CLI.
resp, err := http.Get(httpURL.String())
if err != nil {
return fmt.Errorf("Failed to fetch URL: %w", err)
}
defer resp.Body.Close()

// The whole response body is read into memory before being written.
output, err = io.ReadAll(resp.Body)
if err != nil {
return fmt.Errorf("Failed to read response: %w", err)
}
// Start from the server-provided header, then let the helper replace it when
// it is empty or application/octet-stream.
contentType = resp.Header.Get("Content-Type")
contentType = useExtensionIfUnknownContentType(contentType, output, outputString)

} else {
// Not an http(s) URL: treat the string as a data URL.
dataurlObj, err := dataurl.DecodeString(outputString)
if err != nil {
return fmt.Errorf("Failed to decode dataurl: %w", err)
}
output = dataurlObj.Data
contentType = dataurlObj.ContentType()
}
output := dataurlObj.Data

if addExtension {
extension := mime.ExtensionByType(dataurlObj.ContentType())
if extension != "" {
outputPath += extension
// An empty result from mime.ExtensionByType leaves outputPath unchanged.
if ext := mime.ExtensionByType(contentType); ext != "" {
outputPath += ext
}
}

if err := writeOutput(outputPath, output); err != nil {
return err
return writeOutput(outputPath, output)
}

// getHTTPURL parses str as a URL and reports whether its scheme is "http" or
// "https". On success it returns the parsed URL and true; for any other scheme,
// or when str does not parse, it returns nil and false.
func getHTTPURL(str string) (*url.URL, bool) {
	parsed, err := url.Parse(str)
	if err != nil {
		return nil, false
	}

	switch parsed.Scheme {
	case "http", "https":
		return parsed, true
	default:
		return nil, false
	}
}

return nil
// useExtensionIfUnknownContentType returns a best-effort MIME type for content
// that was obtained from filename, which may be a plain file name or a URL.
//
// If contentType is empty or application/octet-stream (media type parameters
// such as "; charset=binary" are ignored for this check), the type is first
// looked up from the file extension and, if that fails, sniffed from the
// content itself. Any other contentType is returned unchanged.
func useExtensionIfUnknownContentType(contentType string, content []byte, filename string) string {
	mediaType := contentType
	if parsed, _, err := mime.ParseMediaType(contentType); err == nil {
		mediaType = parsed
	}
	if mediaType != "" && mediaType != "application/octet-stream" {
		return contentType
	}

	// For URLs, take the extension from the path component only, so that a
	// query string or fragment ("out.png?sig=abc") is not treated as part of
	// the extension.
	name := filename
	if u, err := url.Parse(filename); err == nil && u.Scheme != "" && u.Host != "" {
		name = u.Path
	}

	if ext := filepath.Ext(name); ext != "" {
		if mimeType := mime.TypeByExtension(ext); mimeType != "" {
			return mimeType
		}
	}

	// DetectContentType always returns a valid MIME type; it falls back to
	// application/octet-stream when nothing more specific matches.
	return http.DetectContentType(content)
}

func parseInputFlags(inputs []string) (predict.Inputs, error) {
Expand Down
3 changes: 3 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,9 @@ dependencies = [
"structlog>=20,<25",
"typing_extensions>=4.4.0",
"uvicorn[standard]>=0.12,<1",

# TODO(andreas): re-implement replicate functionality in pure python
"replicate>=1.0.4",
]

dynamic = ["version"]
Expand Down
2 changes: 2 additions & 0 deletions python/cog/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from pydantic import BaseModel

from .base_predictor import BasePredictor
from .include import include
from .mimetypes_ext import install_mime_extensions
from .server.scope import current_scope, emit_metric
from .types import (
Expand Down Expand Up @@ -36,4 +37,5 @@
"Input",
"Path",
"Secret",
"include",
]
4 changes: 4 additions & 0 deletions python/cog/command/ast_openapi_schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@
"title": "Output File Prefix",
"type": "string"
},
"run_token": {
"title": "Run Token",
"type": "string"
},
"webhook": {
"format": "uri",
"maxLength": 65536,
Expand Down
99 changes: 99 additions & 0 deletions python/cog/include.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
import os
import sys
from dataclasses import dataclass
from typing import Any, Callable, Dict, Optional, Tuple

import replicate
from replicate.exceptions import ModelError
from replicate.model import Model
from replicate.prediction import Prediction
from replicate.run import _has_output_iterator_array_type
from replicate.version import Version

from cog.server.scope import current_scope


def _find_api_token() -> str:
token = os.environ.get("REPLICATE_API_TOKEN")
if token:
print("Using Replicate API token from environment", file=sys.stderr)
return token

token = current_scope()._run_token

if not token:
raise ValueError("No run token found")

return token


@dataclass
class Run:
    """A started prediction paired with the model version it runs on."""

    prediction: Prediction
    version: Version

    def wait(self) -> Any:
        """Block until the prediction finishes and return its output.

        When the version's output is an iterator array type, the output items
        are concatenated into a single string.

        Raises:
            ModelError: If the prediction ends with status ``"failed"``.
        """
        prediction = self.prediction
        prediction.wait()

        if prediction.status == "failed":
            raise ModelError(prediction)

        if not _has_output_iterator_array_type(self.version):
            return prediction.output
        return "".join(prediction.output)

    def logs(self) -> Optional[str]:
        """Reload the prediction and return its current logs."""
        prediction = self.prediction
        prediction.reload()
        return prediction.logs


@dataclass
class Function:
    """Callable handle for a Replicate model given as ``owner/name[:version]``.

    Calling the instance starts a prediction and waits for its output;
    ``start`` returns a ``Run`` without waiting.
    """

    function_ref: str

    def _client(self) -> replicate.Client:
        """Build a Replicate client authenticated with the resolved API token."""
        return replicate.Client(api_token=_find_api_token())

    def _split_function_ref(self) -> Tuple[str, str, Optional[str]]:
        """Split ``function_ref`` into ``(owner, name, version)``.

        ``version`` is ``None`` when the reference has no ``:version`` suffix.

        Raises:
            ValueError: If the reference is not of the form ``owner/name`` or
                ``owner/name:version``.
        """
        owner, slash, rest = self.function_ref.partition("/")
        name, _, version = rest.partition(":")
        malformed = (
            not slash
            or not owner
            or not name
            or "/" in rest
            or ":" in version
        )
        if malformed:
            raise ValueError(
                f"Invalid function reference {self.function_ref!r}: "
                "expected 'owner/name' or 'owner/name:version'"
            )
        return owner, name, version or None

    def _model(self) -> Model:
        """Fetch the model named by ``function_ref``."""
        model_owner, model_name, _ = self._split_function_ref()
        return self._client().models.get(f"{model_owner}/{model_name}")

    def _version(self) -> Version:
        """Fetch the referenced version, or the model's latest version if none is given."""
        _, _, model_version = self._split_function_ref()
        model = self._model()
        if model_version:
            return model.versions.get(model_version)
        return model.latest_version

    # ``**inputs: Any`` annotates each keyword value; the mapping itself is
    # implicitly ``Dict[str, Any]``.
    def __call__(self, **inputs: Any) -> Any:
        """Start a prediction with ``inputs`` and block until its output is ready."""
        return self.start(**inputs).wait()

    def start(self, **inputs: Any) -> Run:
        """Create a prediction with ``inputs`` and return it without waiting."""
        version = self._version()
        prediction = self._client().predictions.create(version=version, input=inputs)
        print(f"Running {self.function_ref}: https://replicate.com/p/{prediction.id}")

        return Run(prediction, version)

    @property
    def default_example(self) -> Optional[Prediction]:
        """The model's default example prediction, if it has one."""
        return self._model().default_example

    # ``typing.Dict`` matches the rest of this module and, unlike ``dict[...]``,
    # can be evaluated in an annotation on Python versions before 3.9.
    @property
    def openapi_schema(self) -> Dict[Any, Any]:
        """The OpenAPI schema of the resolved version."""
        return self._version().openapi_schema


def include(function_ref: str) -> Callable[..., Any]:
    """Return a callable ``Function`` wrapping the model reference ``function_ref``."""
    wrapped = Function(function_ref)
    return wrapped
2 changes: 2 additions & 0 deletions python/cog/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,8 @@ class PredictionRequest(PredictionBaseModel):
default=WebhookEvent.default_events(),
)

run_token: Optional[str] = None

@classmethod
def with_types(cls, input_type: Type[Any]) -> Any:
# [compat] Input is implicitly optional -- previous versions of the
Expand Down
1 change: 1 addition & 0 deletions python/cog/server/eventtypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,3 +69,4 @@ class Envelope:
Done,
]
tag: Optional[str] = None
run_token: Optional[str] = None
4 changes: 3 additions & 1 deletion python/cog/server/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,9 @@ def predict(
payload = prediction.input.copy()

sid = self._worker.subscribe(task.handle_event, tag=tag)
task.track(self._worker.predict(payload, tag=tag))
task.track(
self._worker.predict(payload, tag=tag, run_token=prediction.run_token)
)
task.add_done_callback(self._task_done_callback(tag, sid))

return task
Expand Down
2 changes: 2 additions & 0 deletions python/cog/server/scope.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
# State attached to the scope that current_scope() hands out
# (presumably one per running prediction -- confirm against the caller).
@frozen
class Scope:
# Callback taking a metric name and a numeric (float or int) value.
record_metric: Callable[[str, Union[float, int]], None]

# Optional run token; cog.include._find_api_token reads it from the current
# scope when REPLICATE_API_TOKEN is not set in the environment.
_run_token: Optional[str] = None
# Optional tag; NOTE(review): presumably the same tag the runner passes to the
# worker for this prediction -- confirm.
_tag: Optional[str] = None


Expand Down
Loading
Loading