Skip to content
Permalink

Comparing changes

This is a direct comparison between two commits made in this repository or its related repositories. View the default comparison for this range or learn more about diff comparisons.

Open a pull request

Create a new pull request by comparing changes across two branches. If you need to, you can also . Learn more about diff comparisons here.
base repository: modal-labs/modal-client
Failed to load repositories. Confirm that selected base ref is valid, then try again.
Loading
base: 3f30fd324eefc010341239694904d756e6a79384
Choose a base ref
..
head repository: modal-labs/modal-client
Failed to load repositories. Confirm that selected head ref is valid, then try again.
Loading
compare: c1b15cbcc781eb0bddf4e9d4d5c6d65cfa3a91a3
Choose a head ref
23 changes: 23 additions & 0 deletions client_test/cls_test.py
Original file line number Diff line number Diff line change
@@ -302,3 +302,26 @@ def f(self, y):
def test_local_enter():
obj = ClsWithEnter(get_thread_id())
assert obj.f(10) == 420


inheritance_stub = Stub()


class BaseCls:
def __enter__(self):
self.x = 2

@method()
def run(self, y):
return self.x * y


@inheritance_stub.cls()
class DerivedCls(BaseCls):
pass


def test_derived_cls(client, servicer):
with inheritance_stub.run(client=client):
# default servicer fn just squares the number
assert DerivedCls.run.remote(3) == 9
10 changes: 10 additions & 0 deletions client_test/container_test.py
Original file line number Diff line number Diff line change
@@ -752,3 +752,13 @@ def test_param_cls_function_calling_local(unix_servicer, event_loop):
assert len(items) == 1
assert items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
assert items[0].result.data == serialize("111 foo 42")


@skip_windows_unix_socket
def test_derived_cls(unix_servicer, event_loop):
client, items = _run_container(
unix_servicer, "modal_test_support.functions", "DerivedCls.run", inputs=_get_inputs(((3,), {}))
)
assert len(items) == 1
assert items[0].result.status == api_pb2.GenericResult.GENERIC_STATUS_SUCCESS
assert items[0].result.data == serialize(6)
2 changes: 1 addition & 1 deletion client_test/token_flow_test.py
Original file line number Diff line number Diff line change
@@ -9,7 +9,7 @@
@pytest.mark.asyncio
async def test_token_flow_server(servicer, client):
tf = TokenFlow(client)
async with tf.start() as (token_flow_id, _):
async with tf.start() as (token_flow_id, _, _):
# Make a request against the local web server and make sure it validates
localhost_url = f"http://localhost:{servicer.token_flow_localhost_port}"
async with aiohttp.ClientSession() as session:
13 changes: 11 additions & 2 deletions modal/_function_utils.py
Original file line number Diff line number Diff line change
@@ -86,8 +86,17 @@ class FunctionInfo:
def __init__(self, f, serialized=False, name_override: Optional[str] = None, cls: Optional[Type] = None):
self.raw_f = f
self.cls = cls
# TODO(erikbern): if f.__qualname__ != f.__name__, we should infer the class name instead
self.function_name = name_override if name_override is not None else f.__qualname__

if name_override is not None:
self.function_name = name_override
elif f.__qualname__ != f.__name__ and not serialized:
# Class function.
if len(f.__qualname__.split(".")) > 2:
raise InvalidError("@stub.cls classes must be defined in global scope")
self.function_name = f"{cls.__name__}.{f.__name__}"
else:
self.function_name = f.__qualname__

self.signature = inspect.signature(f)
module = inspect.getmodule(f)

4 changes: 3 additions & 1 deletion modal/cli/token.py
Original file line number Diff line number Diff line change
@@ -56,7 +56,7 @@ def _new_token(
with Client.unauthenticated_client(server_url) as client:
token_flow = TokenFlow(client)

with token_flow.start(source, next_url) as (token_flow_id, web_url):
with token_flow.start(source, next_url) as (token_flow_id, web_url, code):
with console.status("Waiting for authentication in the web browser", spinner="dots"):
# Open the web url in the browser
if webbrowser.open_new_tab(web_url):
@@ -69,6 +69,8 @@ def _new_token(
" - please go to this URL manually and complete the flow:"
)
console.print(f"\n[link={web_url}]{web_url}[/link]\n")
if code:
console.print(f"Enter this code: [yellow]{code}[/yellow]\n")

with console.status("Waiting for token flow to complete...", spinner="dots") as status:
for attempt in itertools.count():
13 changes: 8 additions & 5 deletions modal/stub.py
Original file line number Diff line number Diff line change
@@ -609,11 +609,14 @@ def wrapper(user_cls: CLS_T) -> _Cls:
partial_functions: Dict[str, PartialFunction] = {}
functions: Dict[str, _Function] = {}

for k, v in user_cls.__dict__.items():
if isinstance(v, PartialFunction):
partial_functions[k] = v
partial_function = synchronizer._translate_in(v) # TODO: remove need for?
functions[k] = decorator(partial_function, user_cls)
for parent_cls in user_cls.mro():
if parent_cls is object:
continue
for k, v in parent_cls.__dict__.items():
if isinstance(v, PartialFunction):
partial_functions[k] = v
partial_function = synchronizer._translate_in(v) # TODO: remove need for?
functions[k] = decorator(partial_function, user_cls)

tag: str = user_cls.__name__
cls: _Cls = _Cls.from_local(user_cls, functions)
10 changes: 7 additions & 3 deletions modal/token_flow.py
Original file line number Diff line number Diff line change
@@ -19,10 +19,11 @@ def __init__(self, client: _Client):
@asynccontextmanager
async def start(
self, utm_source: Optional[str] = None, next_url: Optional[str] = None
) -> AsyncGenerator[Tuple[str, str], None]:
) -> AsyncGenerator[Tuple[str, str, str], None]:
"""mdmd:hidden"""
# Run a temporary http server returning the token id on /
# This helps us add direct validation later
# TODO(erikbern): handle failure launching server

async def slash(request):
headers = {"Access-Control-Allow-Origin": "*"}
@@ -42,14 +43,17 @@ async def slash(request):
)
resp = await self.stub.TokenFlowCreate(req)
self.token_flow_id = resp.token_flow_id
yield (resp.token_flow_id, resp.web_url)
self.wait_secret = resp.wait_secret
yield (resp.token_flow_id, resp.web_url, resp.code)

async def finish(
self, timeout: float = 40.0, grpc_extra_timeout: float = 5.0
) -> Optional[api_pb2.TokenFlowWaitResponse]:
"""mdmd:hidden"""
# Wait for token flow to finish
req = api_pb2.TokenFlowWaitRequest(token_flow_id=self.token_flow_id, timeout=timeout)
req = api_pb2.TokenFlowWaitRequest(
token_flow_id=self.token_flow_id, timeout=timeout, wait_secret=self.wait_secret
)
resp = await self.stub.TokenFlowWait(req, timeout=(timeout + grpc_extra_timeout))
if not resp.timeout:
return resp
17 changes: 14 additions & 3 deletions modal_proto/api.proto
Original file line number Diff line number Diff line change
@@ -90,9 +90,17 @@ enum WebhookAsyncMode {
WEBHOOK_ASYNC_MODE_AUTO = 4; // redirect to polling endpoint if execution time nears the http timeout
}

// An opinionated interpretation of a message in the ASGI HTTP specification.
// Spec: https://asgi.readthedocs.io/en/latest/specs/www.html
message AsgiProtocol {
// A web endpoint connection-related message.
//
// Modal's internal web endpoint runtime effectively acts as a global web server
// that schedules requests to tasks that are spawned on-demand, so we need an
// internal protocol to model HTTP requests. This is that protocol.
//
// We base this protocol on Python's ASGI specification, which is a popular
// interface between applications and web servers.
//
// ASGI Spec: https://asgi.readthedocs.io/en/latest/specs/www.html
message Asgi {
// Single-request connection scope
message Scope {
string http_version = 1;
@@ -1285,11 +1293,14 @@ message TokenFlowCreateRequest {
message TokenFlowCreateResponse {
string token_flow_id = 1;
string web_url = 2;
string code = 3;
string wait_secret = 4;
};

message TokenFlowWaitRequest {
float timeout = 1;
string token_flow_id = 2;
string wait_secret = 3;
}

message TokenFlowWaitResponse {
17 changes: 17 additions & 0 deletions modal_proto/options.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
// Defines custom options used internally at Modal.
// Custom options must be in the range 50000-99999.
// Reference: https://protobuf.dev/programming-guides/proto2/#customoptions
syntax = "proto3";

import "google/protobuf/descriptor.proto";

package modal.options;

extend google.protobuf.FieldOptions {
optional bool audit_target_attr = 50000;
}

extend google.protobuf.MethodOptions {
optional string audit_event_name = 50000;
optional string audit_event_description = 50001;
}
14 changes: 14 additions & 0 deletions modal_test_support/functions.py
Original file line number Diff line number Diff line change
@@ -191,3 +191,17 @@ async def sleep_700_async(x):

def unassociated_function(x):
return 100 - x


class BaseCls:
def __enter__(self):
self.x = 2

@method()
def run(self, y):
return self.x * y


@stub.cls()
class DerivedCls(BaseCls):
pass
2 changes: 1 addition & 1 deletion modal_version/_version_generated.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
# Copyright Modal Labs 2023
build_number = 3541
build_number = 3568
2 changes: 1 addition & 1 deletion tasks.py
Original file line number Diff line number Diff line change
@@ -28,7 +28,7 @@ def protoc(ctx):
+ " --python_out=. --grpclib_python_out=. --grpc_python_out=. --mypy_out=. --mypy_grpc_out=."
)
print(py_protoc)
ctx.run(f"{py_protoc} -I . modal_proto/api.proto")
ctx.run(f"{py_protoc} -I . " "modal_proto/api.proto " "modal_proto/options.proto ")


@task