diff --git a/src/agents/agent.py b/src/agents/agent.py
index 2723e678..28fcbaeb 100644
--- a/src/agents/agent.py
+++ b/src/agents/agent.py
@@ -93,7 +93,15 @@ class Agent(Generic[TContext]):
     modularity.
     """
 
-    model: str | Model | None = None
+    model: (
+        str
+        | Model
+        | Callable[
+            [RunContextWrapper[TContext], Agent[TContext]],
+            MaybeAwaitable[str | Model],
+        ]
+        | None
+    ) = None
     """The model implementation to use when invoking the LLM.
 
     By default, if not set, the agent will use the default model configured in
@@ -205,3 +213,17 @@ async def get_system_prompt(self, run_context: RunContextWrapper[TContext]) -> s
             logger.error(f"Instructions must be a string or a function, got {self.instructions}")
 
         return None
+
+    async def get_model(self, run_context: RunContextWrapper[TContext]) -> str | Model | None:
+        """Get the model for the agent."""
+        if isinstance(self.model, (str, Model)):
+            return self.model
+        elif callable(self.model):
+            if inspect.iscoroutinefunction(self.model):
+                return await cast(Awaitable[str | Model], self.model(run_context, self))
+            else:
+                return cast(str | Model, self.model(run_context, self))
+        elif self.model is not None:
+            logger.error(f"Model must be a string, Model object, or a function, got {self.model}")
+
+        return None
diff --git a/src/agents/run.py b/src/agents/run.py
index 934400fe..53043eed 100644
--- a/src/agents/run.py
+++ b/src/agents/run.py
@@ -628,7 +628,7 @@ async def _run_single_turn_streamed(
 
         handoffs = cls._get_handoffs(agent)
 
-        model = cls._get_model(agent, run_config)
+        model = await cls._get_model(agent, run_config, context_wrapper)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
         final_response: ModelResponse | None = None
 
@@ -857,7 +857,7 @@ async def _get_new_response(
         context_wrapper: RunContextWrapper[TContext],
         run_config: RunConfig,
     ) -> ModelResponse:
-        model = cls._get_model(agent, run_config)
+        model = await cls._get_model(agent, run_config, context_wrapper)
         model_settings = agent.model_settings.resolve(run_config.model_settings)
         new_response = await model.get_response(
             system_instructions=system_prompt,
@@ -893,12 +893,22 @@ def _get_handoffs(cls, agent: Agent[Any]) -> list[Handoff]:
         return handoffs
 
     @classmethod
-    def _get_model(cls, agent: Agent[Any], run_config: RunConfig) -> Model:
+    async def _get_model(
+        cls,
+        agent: Agent[Any],
+        run_config: RunConfig,
+        context_wrapper: RunContextWrapper[TContext],
+    ) -> Model:
         if isinstance(run_config.model, Model):
             return run_config.model
         elif isinstance(run_config.model, str):
             return run_config.model_provider.get_model(run_config.model)
         elif isinstance(agent.model, Model):
             return agent.model
+        elif callable(agent.model):
+            model = await agent.get_model(context_wrapper)
+            if isinstance(model, Model):
+                return model
+            return run_config.model_provider.get_model(model)
 
         return run_config.model_provider.get_model(agent.model)
diff --git a/tests/test_agent_config.py b/tests/test_agent_config.py
index 44339dad..8ed81c3c 100644
--- a/tests/test_agent_config.py
+++ b/tests/test_agent_config.py
@@ -27,6 +27,29 @@ async def async_instructions(agent: Agent[None], context: RunContextWrapper[None
     assert await agent.get_system_prompt(context) == "async_123"
 
 
+@pytest.mark.asyncio
+async def test_model():
+    agent = Agent[None](
+        name="test",
+        model="gpt-4",
+    )
+    context = RunContextWrapper(None)
+
+    assert await agent.get_model(context) == "gpt-4"
+
+    def sync_model(context: RunContextWrapper[None], agent: Agent[None]):
+        return "sync-model"
+
+    agent = agent.clone(model=sync_model)
+    assert await agent.get_model(context) == "sync-model"
+
+    async def async_model(context: RunContextWrapper[None], agent: Agent[None]):
+        return "async-model"
+
+    agent = agent.clone(model=async_model)
+    assert await agent.get_model(context) == "async-model"
+
+
 @pytest.mark.asyncio
 async def test_handoff_with_agents():
     agent_1 = Agent(