Skip to content

Commit

Permalink
fix(api): performance (#3052)
Browse files Browse the repository at this point in the history
Co-authored-by: Tal <[email protected]>
  • Loading branch information
shahargl and talboren authored Jan 19, 2025
1 parent 2a453f0 commit 02e90af
Show file tree
Hide file tree
Showing 12 changed files with 215 additions and 23 deletions.
8 changes: 8 additions & 0 deletions docker/Dockerfile.api
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,14 @@ ENV PYTHONFAULTHANDLER=1 \
PYTHONHASHSEED=random \
PYTHONUNBUFFERED=1

# THIS IS FOR DEBUGGING PURPOSES
# RUN apt-get update && \
# apt-get install -y --no-install-recommends \
# iproute2 \
# net-tools \
# procps && \
# rm -rf /var/lib/apt/lists/*

RUN useradd --user-group --system --create-home --no-log-init keep
WORKDIR /app

Expand Down
2 changes: 1 addition & 1 deletion keep/api/alert_deduplicator/alert_deduplicator.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,7 +95,7 @@ def _apply_deduplication_rule(
)
alert.isPartialDuplicate = True
else:
self.logger.info(
self.logger.debug(
"Alert is not deduplicated",
extra={
"alert_id": alert.id,
Expand Down
64 changes: 64 additions & 0 deletions keep/api/api.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import asyncio
import logging
import os
import time
from contextlib import asynccontextmanager
from functools import wraps
from importlib import metadata
from typing import Awaitable, Callable

import requests
import uvicorn
Expand All @@ -13,6 +16,7 @@
from prometheus_fastapi_instrumentator import Instrumentator
from slowapi import _rate_limit_exceeded_handler
from slowapi.errors import RateLimitExceeded
from slowapi.middleware import SlowAPIMiddleware
from starlette.middleware.cors import CORSMiddleware
from starlette_context import plugins
from starlette_context.middleware import RawContextMiddleware
Expand Down Expand Up @@ -79,6 +83,8 @@
CONSUMER = config("CONSUMER", default="true", cast=bool)
TOPOLOGY = config("KEEP_TOPOLOGY_PROCESSOR", default="false", cast=bool)
KEEP_DEBUG_TASKS = config("KEEP_DEBUG_TASKS", default="false", cast=bool)
KEEP_DEBUG_MIDDLEWARES = config("KEEP_DEBUG_MIDDLEWARES", default="false", cast=bool)
KEEP_USE_LIMITER = config("KEEP_USE_LIMITER", default="false", cast=bool)

AUTH_TYPE = config("AUTH_TYPE", default=IdentityManagerTypes.NOAUTH.value).lower()
try:
Expand Down Expand Up @@ -329,6 +335,8 @@ async def catch_exception(request: Request, exc: Exception):
)

app.add_middleware(LoggingMiddleware)
if KEEP_USE_LIMITER:
app.add_middleware(SlowAPIMiddleware)

if config("KEEP_METRICS", default="true", cast=bool):
Instrumentator(
Expand All @@ -339,6 +347,62 @@ async def catch_exception(request: Request, exc: Exception):
if config("KEEP_OTEL_ENABLED", default="true", cast=bool):
keep.api.observability.setup(app)

# if debug middlewares are enabled, instrument them
if KEEP_DEBUG_MIDDLEWARES:
logger.info("Instrumenting middlewares")
app = instrument_middleware(app)
logger.info("Instrumented middlewares")
return app


logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


# SHAHAR:
# This (and instrument_middleware) is a helper function to wrap the call of a middleware with timing
# It will log the time it took for the middleware to run
# It should NOT be used in production!
def wrap_call(middleware_cls, original_call):
    """Wrap an ASGI middleware ``__call__`` so each HTTP request logs timing.

    Debug-only helper (see the note above): the returned coroutine logs how
    long ``original_call`` took for every HTTP request. Non-HTTP scopes
    (lifespan, websocket) pass through untouched.

    Args:
        middleware_cls: The middleware class being instrumented (kept for
            signature compatibility; the name logged is read from ``self``
            at call time).
        original_call: The middleware's original ``__call__`` coroutine
            function.

    Returns:
        The timed wrapper, or ``original_call`` unchanged if it was already
        wrapped — wrapping is idempotent.
    """
    # if the call is already wrapped, return it
    if hasattr(original_call, "_timing_wrapped"):
        return original_call

    @wraps(original_call)
    async def timed_call(
        self,
        scope: dict,
        receive: Callable[[], Awaitable[dict]],
        send: Callable[[dict], Awaitable[None]],
    ):
        # Only time HTTP requests; everything else goes straight through.
        if scope["type"] != "http":
            return await original_call(self, scope, receive, send)

        # perf_counter is monotonic: unlike time.time(), the measured
        # duration cannot jump or go negative if the wall clock is adjusted
        # mid-request.
        start_time = time.perf_counter()
        try:
            return await original_call(self, scope, receive, send)
        finally:
            # Log even when the downstream call raises.
            process_time = (time.perf_counter() - start_time) * 1000
            path = scope.get("path", "")
            method = scope.get("method", "")
            middleware_name = self.__class__.__name__
            logger.info(
                f"⏱️ {middleware_name:<40} {method} {path} took {process_time:>8.2f}ms"
            )

    # Marker so a second instrumentation pass is a no-op.
    timed_call._timing_wrapped = True
    return timed_call


def instrument_middleware(app):
    """Patch every registered middleware's ``__call__`` with a timed wrapper.

    Iterates the app's ``user_middleware`` stack and replaces each
    middleware *class*'s ``__call__`` with the timing wrapper from
    ``wrap_call``. Debug-only helper; safe to call more than once because
    ``wrap_call`` is idempotent.

    Args:
        app: The FastAPI/Starlette application whose middleware to wrap.

    Returns:
        The same ``app`` instance, for chaining.
    """
    # Get middleware from FastAPI app
    for middleware in app.user_middleware:
        if hasattr(middleware.cls, "__call__"):
            # wrap_call already applies functools.wraps and returns the
            # input unchanged when it is wrapped, so no extra re-wrapping
            # with wraps(...) is needed here.
            middleware.cls.__call__ = wrap_call(
                middleware.cls, middleware.cls.__call__
            )
    return app


Expand Down
5 changes: 4 additions & 1 deletion keep/api/core/dependencies.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@ async def extract_generic_body(request: Request) -> dict | bytes | FormData:


def get_pusher_client() -> Pusher | None:
logger.debug("Getting pusher client")
pusher_disabled = os.environ.get("PUSHER_DISABLED", "false") == "true"
pusher_host = os.environ.get("PUSHER_HOST")
pusher_app_id = os.environ.get("PUSHER_APP_ID")
Expand All @@ -53,7 +54,7 @@ def get_pusher_client() -> Pusher | None:
return None

# TODO: defaults on open source no docker
return Pusher(
pusher = Pusher(
host=pusher_host,
port=(
int(os.environ.get("PUSHER_PORT"))
Expand All @@ -66,3 +67,5 @@ def get_pusher_client() -> Pusher | None:
ssl=False if os.environ.get("PUSHER_USE_SSL", False) is False else True,
cluster=os.environ.get("PUSHER_CLUSTER"),
)
logging.debug("Pusher client initialized")
return pusher
6 changes: 5 additions & 1 deletion keep/api/core/limiter.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@

logger = logging.getLogger(__name__)
limiter_enabled = config("KEEP_USE_LIMITER", default="false", cast=bool)
default_limit = config("KEEP_LIMIT_CONCURRENCY", default="100/minute", cast=str)

logger.warning(f"Rate limiter is {'enabled' if limiter_enabled else 'disabled'}")

limiter = Limiter(key_func=get_remote_address, enabled=limiter_enabled)
limiter = Limiter(
key_func=get_remote_address, enabled=limiter_enabled, default_limits=[default_limit]
)
7 changes: 6 additions & 1 deletion keep/api/logging.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,7 +229,7 @@ def format(self, record):
},
"dev_terminal": {
"()": DevTerminalFormatter,
"format": "%(asctime)s - %(thread)s %(threadName)s %(levelname)s - %(message)s",
"format": "%(asctime)s - %(thread)s %(otelTraceID)s %(threadName)s %(levelname)s - %(message)s",
},
},
"handlers": {
Expand All @@ -255,6 +255,11 @@ def format(self, record):
"level": LOG_LEVEL,
"propagate": False,
},
"slowapi": {
"handlers": ["default"],
"level": LOG_LEVEL,
"propagate": False,
},
# shut the open telemetry logger down since it keep pprints <Token var=<ContextVar name='current_context' default={} at was created in a different Context
# https://github.com/open-telemetry/opentelemetry-python/issues/2606
"opentelemetry.context": {
Expand Down
12 changes: 9 additions & 3 deletions keep/api/middlewares.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

from keep.api.core.config import config
from keep.api.core.db import get_api_key

logger = logging.getLogger(__name__)
Expand All @@ -15,6 +16,8 @@
except Exception:
KEEP_VERSION = os.environ.get("KEEP_VERSION", "unknown")

KEEP_EXTRACT_IDENTITY = config("KEEP_EXTRACT_IDENTITY", default="true", cast=bool)


def _extract_identity(request: Request, attribute="email") -> str:
try:
Expand All @@ -28,9 +31,12 @@ def _extract_identity(request: Request, attribute="email") -> str:
if not api_key:
return "anonymous"

api_key = get_api_key(api_key)
if api_key:
return api_key.tenant_id
# allow disabling the extraction of the identity from the api key
# for high performance scenarios
if KEEP_EXTRACT_IDENTITY:
api_key = get_api_key(api_key)
if api_key:
return api_key.tenant_id
return "anonymous"
except Exception:
return "anonymous"
Expand Down
14 changes: 2 additions & 12 deletions keep/api/routes/alerts.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@

import celpy
from arq import ArqRedis
from fastapi import APIRouter, BackgroundTasks, Depends, HTTPException, Query, Request
from fastapi import APIRouter, Depends, HTTPException, Query, Request
from fastapi.responses import JSONResponse
from pusher import Pusher

Expand All @@ -28,7 +28,6 @@
)
from keep.api.core.dependencies import extract_generic_body, get_pusher_client
from keep.api.core.elastic import ElasticClient
from keep.api.core.limiter import limiter
from keep.api.core.metrics import running_tasks_by_process_gauge, running_tasks_gauge
from keep.api.models.alert import (
AlertDto,
Expand All @@ -53,7 +52,7 @@
logger = logging.getLogger(__name__)

REDIS = os.environ.get("REDIS", "false") == "true"
EVENT_WORKERS = int(config("KEEP_EVENT_WORKERS", default=50, cast=int))
EVENT_WORKERS = int(config("KEEP_EVENT_WORKERS", default=5, cast=int))

# Create dedicated threadpool
process_event_executor = ThreadPoolExecutor(
Expand Down Expand Up @@ -316,7 +315,6 @@ def discard_future(


def create_process_event_task(
bg_tasks: BackgroundTasks,
tenant_id: str,
provider_type: str | None,
provider_id: str | None,
Expand Down Expand Up @@ -358,16 +356,13 @@ def create_process_event_task(
response_model=AlertDto | list[AlertDto],
status_code=202,
)
@limiter.limit(config("KEEP_LIMIT_CONCURRENCY", default="100/minute", cast=str))
async def receive_generic_event(
event: AlertDto | list[AlertDto] | dict,
bg_tasks: BackgroundTasks,
request: Request,
fingerprint: str | None = None,
authenticated_entity: AuthenticatedEntity = Depends(
IdentityManagerFactory.get_auth_verifier(["write:alert"])
),
pusher_client: Pusher = Depends(get_pusher_client),
):
"""
A generic webhook endpoint that can be used by any provider to send alerts to Keep.
Expand Down Expand Up @@ -402,7 +397,6 @@ async def receive_generic_event(
task_name = job.job_id
else:
task_name = create_process_event_task(
bg_tasks,
authenticated_entity.tenant_id,
None,
None,
Expand Down Expand Up @@ -447,18 +441,15 @@ async def webhook_challenge():
description="Receive an alert event from a provider",
status_code=202,
)
@limiter.limit(config("KEEP_LIMIT_CONCURRENCY", default="100/minute", cast=str))
async def receive_event(
provider_type: str,
bg_tasks: BackgroundTasks,
request: Request,
provider_id: str | None = None,
fingerprint: str | None = None,
event=Depends(extract_generic_body),
authenticated_entity: AuthenticatedEntity = Depends(
IdentityManagerFactory.get_auth_verifier(["write:alert"])
),
pusher_client: Pusher = Depends(get_pusher_client),
) -> dict[str, str]:
trace_id = request.state.trace_id
running_tasks: set = request.state.background_tasks
Expand Down Expand Up @@ -512,7 +503,6 @@ async def receive_event(
task_name = job.job_id
else:
task_name = create_process_event_task(
bg_tasks,
authenticated_entity.tenant_id,
provider_type,
provider_id,
Expand Down
2 changes: 2 additions & 0 deletions keep/api/routes/providers.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@

from keep.api.core.config import config
from keep.api.core.db import count_alerts, get_provider_distribution, get_session
from keep.api.core.limiter import limiter
from keep.api.models.db.provider import Provider
from keep.api.models.provider import Provider as ProviderDTO
from keep.api.models.provider import ProviderAlertsCountResponseDTO
Expand Down Expand Up @@ -133,6 +134,7 @@ def get_provider_logs(
description="export all installed providers",
response_model=list[ProviderDTO],
)
@limiter.exempt
def get_installed_providers(
authenticated_entity: AuthenticatedEntity = Depends(
IdentityManagerFactory.get_auth_verifier(["read:providers"])
Expand Down
Loading

0 comments on commit 02e90af

Please sign in to comment.