Consistently raise ContentFilterError when model response is empty because of content filter (#3634)
Diff in `pydantic_ai.exceptions`:

```diff
@@ -24,6 +24,7 @@
     'UsageLimitExceeded',
     'ModelAPIError',
     'ModelHTTPError',
+    'ContentFilterError',
     'IncompleteToolCall',
     'FallbackExceptionGroup',
 )
```
```diff
@@ -152,6 +153,10 @@ def __str__(self) -> str:
         return self.message
 
 
+class ContentFilterError(UnexpectedModelBehavior):
+    """Raised when content filtering is triggered by the model provider."""
+
+
 class ModelAPIError(AgentRunError):
     """Raised when a model provider API request fails."""
```
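For orientation, a minimal sketch of how calling code could handle the new exception (the agent setup below is illustrative and not part of this diff):

```python
import asyncio

from pydantic_ai import Agent
from pydantic_ai.exceptions import ContentFilterError

agent = Agent('openai:gpt-4o')


async def main() -> None:
    try:
        result = await agent.run('Tell me a joke.')
        print(result.output)
    except ContentFilterError as exc:
        # Still a subclass of UnexpectedModelBehavior, so broad existing
        # handlers keep working; this just allows narrower handling.
        print(f'The provider filtered the response: {exc.message}')


asyncio.run(main())
```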
Diff in the Google model (`google.py`):

```diff
@@ -503,11 +503,7 @@ def _process_response(self, response: GenerateContentResponse) -> ModelResponse:
         finish_reason = _FINISH_REASON_MAP.get(raw_finish_reason)
```
Collaborator: Are you able to test against Vertex AI? It may be worth exposing these extra data points on `provider_details`.
Collaborator: Similar for the Azure error body:

```python
{
    'message': "The response was filtered due to the prompt triggering Azure OpenAI's content management policy. Please modify your prompt and retry. To learn more about our content filtering policies please read our documentation: https://go.microsoft.com/fwlink/?linkid=2198766",
    'type': None,
    'param': 'prompt',
    'code': 'content_filter',
    'status': 400,
    'innererror': {
        'code': 'ResponsibleAIPolicyViolation',
        'content_filter_result': {
            'hate': {'filtered': True, 'severity': 'high'},
            'jailbreak': {'filtered': False, 'detected': False},
            'self_harm': {'filtered': False, 'severity': 'safe'},
            'sexual': {'filtered': False, 'severity': 'safe'},
            'violence': {'filtered': False, 'severity': 'medium'},
        },
    },
}
```
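If that data were surfaced, a hedged sketch of pulling the per-category results out of a body shaped like the one above (the helper name and the envelope handling are assumptions, not code from this PR):

```python
from typing import Any


def extract_azure_filter_results(body: Any) -> dict[str, Any] | None:
    """Return Azure's `content_filter_result` mapping from an error body, if any."""
    if not isinstance(body, dict):
        return None
    # APIStatusError.body typically nests the payload under 'error';
    # the dump above is the inner dict, so fall back to the body itself.
    error = body.get('error', body)
    if not isinstance(error, dict):
        return None
    inner = error.get('innererror')
    return inner.get('content_filter_result') if isinstance(inner, dict) else None


# e.g. {'hate': {'filtered': True, 'severity': 'high'}, ...} is a natural
# candidate for provider_details, per the comments above.
```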
The hunk in `_process_response` continues:

```diff
         if candidate.content is None or candidate.content.parts is None:
-            if finish_reason == 'content_filter' and raw_finish_reason:
-                raise UnexpectedModelBehavior(
-                    f'Content filter {raw_finish_reason.value!r} triggered', response.model_dump_json()
-                )
-            parts = []  # pragma: no cover
+            parts = []
         else:
             parts = candidate.content.parts or []
```
```diff
@@ -707,12 +703,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                     yield self._parts_manager.handle_part(vendor_part_id=uuid4(), part=web_fetch_return)
 
                 if candidate.content is None or candidate.content.parts is None:
-                    if self.finish_reason == 'content_filter' and raw_finish_reason:  # pragma: no cover
-                        raise UnexpectedModelBehavior(
-                            f'Content filter {raw_finish_reason.value!r} triggered', chunk.model_dump_json()
-                        )
-                    else:  # pragma: no cover
-                        continue
+                    continue
 
                 parts = candidate.content.parts
                 if not parts:
```
Diff in the OpenAI model:

```diff
@@ -176,6 +176,20 @@ def _resolve_openai_image_generation_size(
     return mapped_size
 
 
+def _check_azure_content_filter(e: APIStatusError) -> bool:
+    """Check if the error is an Azure content filter error."""
+    if e.status_code == 400:
+        body_any: Any = e.body
+        if isinstance(body_any, dict):
+            body_dict = cast(dict[str, Any], body_any)
+            if (error := body_dict.get('error')) and isinstance(error, dict):
+                error_dict = cast(dict[str, Any], error)
+                return error_dict.get('code') == 'content_filter'
+    return False
+
+
 class OpenAIChatModelSettings(ModelSettings, total=False):
     """Settings used for an OpenAI model request."""
```

Collaborator (on `_check_azure_content_filter`): Can we do this only if […]?
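A quick way to exercise the helper in isolation: a sketch assuming the openai SDK's `APIStatusError(message, *, response, body)` constructor and fabricated Azure-style bodies (neither fixture exists in this PR):

```python
import httpx
from openai import APIStatusError

# Fabricated 400 response; the URL is a placeholder, not a real endpoint.
request = httpx.Request('POST', 'https://my-resource.openai.azure.com/openai/chat/completions')
response = httpx.Response(400, request=request)

filtered = APIStatusError('filtered', response=response, body={'error': {'code': 'content_filter'}})
other = APIStatusError('bad request', response=response, body={'error': {'code': 'invalid_request'}})

assert _check_azure_content_filter(filtered) is True
assert _check_azure_content_filter(other) is False
```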
```diff
@@ -584,6 +598,20 @@ async def _completions_create(
                 extra_body=model_settings.get('extra_body'),
             )
         except APIStatusError as e:
+            if _check_azure_content_filter(e):
+                return chat.ChatCompletion(
+                    id='content_filter',
+                    choices=[
+                        chat.chat_completion.Choice(
+                            finish_reason='content_filter',
+                            index=0,
+                            message=chat.ChatCompletionMessage(content=None, role='assistant'),
+                        )
+                    ],
+                    created=0,
+                    model=self.model_name,
+                    object='chat.completion',
+                )
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover
```
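The apparent design: rather than raising at the HTTP layer, the Azure 400 is converted into a synthetic completion with `finish_reason='content_filter'`, so the same downstream path that handles a native OpenAI content-filter finish produces the error. A standalone sanity check that this minimal object satisfies the SDK's types (it mirrors the fields above; the model name is a placeholder for `self.model_name`):

```python
from openai.types import chat

completion = chat.ChatCompletion(
    id='content_filter',
    choices=[
        chat.chat_completion.Choice(
            finish_reason='content_filter',
            index=0,
            message=chat.ChatCompletionMessage(content=None, role='assistant'),
        )
    ],
    created=0,
    model='gpt-4o',  # placeholder; the real code uses self.model_name
    object='chat.completion',
)
assert completion.choices[0].finish_reason == 'content_filter'
assert completion.choices[0].message.content is None
```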
```diff
@@ -631,6 +659,7 @@ def _process_response(self, response: chat.ChatCompletion | str) -> ModelResponse:
             raise UnexpectedModelBehavior(f'Invalid response from {self.system} chat completions endpoint: {e}') from e
 
         choice = response.choices[0]
 
         items: list[ModelResponsePart] = []
 
         if thinking_parts := self._process_thinking(choice.message):
```
```diff
@@ -1431,6 +1460,19 @@ async def _responses_create(  # noqa: C901
                 extra_body=model_settings.get('extra_body'),
             )
         except APIStatusError as e:
+            if _check_azure_content_filter(e):
+                return responses.Response(
+                    id='content_filter',
+                    model=self.model_name,
+                    created_at=0,
+                    object='response',
+                    status='incomplete',
+                    incomplete_details={'reason': 'content_filter'},  # type: ignore
+                    output=[],
+                    parallel_tool_calls=False,
+                    tool_choice='auto',
+                    tools=[],
+                )
             if (status_code := e.status_code) >= 400:
                 raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e
             raise  # pragma: lax no cover
```
```diff
@@ -2089,6 +2131,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
                 raw_finish_reason = (
                     details.reason if (details := chunk.response.incomplete_details) else chunk.response.status
                 )
 
                 if raw_finish_reason:  # pragma: no branch
                     self.provider_details = {'finish_reason': raw_finish_reason}
                     self.finish_reason = _RESPONSES_FINISH_REASON_MAP.get(raw_finish_reason)
```
Diff in the Google model tests:

```diff
@@ -54,7 +54,13 @@
     WebFetchTool,
     WebSearchTool,
 )
-from pydantic_ai.exceptions import ModelAPIError, ModelHTTPError, ModelRetry, UnexpectedModelBehavior, UserError
+from pydantic_ai.exceptions import (
+    ContentFilterError,
+    ModelAPIError,
+    ModelHTTPError,
+    ModelRetry,
+    UserError,
+)
 from pydantic_ai.messages import (
     BuiltinToolCallEvent,  # pyright: ignore[reportDeprecated]
     BuiltinToolResultEvent,  # pyright: ignore[reportDeprecated]
```
```diff
@@ -994,7 +1000,10 @@ async def test_google_model_safety_settings(allow_model_requests: None, google_p
     )
     agent = Agent(m, instructions='You hate the world!', model_settings=settings)
 
-    with pytest.raises(UnexpectedModelBehavior, match="Content filter 'SAFETY' triggered"):
+    with pytest.raises(
+        ContentFilterError,
+        match='Content filter triggered for model gemini-1.5-flash',
+    ):
         await agent.run('Tell me a joke about a Brazilians.')
```
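Taken together with the raises removed from `google.py`, this new expectation implies the check now lives in one shared place. A hedged reconstruction of what that consolidated check might look like (the actual location is not visible in this excerpt; the message follows the test above):

```python
# Sketch only: raised wherever the response is finalized, after the provider's
# raw finish reason has been mapped to pydantic-ai's normalized values.
if model_response.finish_reason == 'content_filter' and not model_response.parts:
    raise ContentFilterError(f'Content filter triggered for model {model_name}')
```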
```diff
@@ -4610,3 +4619,57 @@ def get_country() -> str:
             ),
         ]
     )
+
+
+async def test_google_stream_empty_chunk(
+    allow_model_requests: None, google_provider: GoogleProvider, mocker: MockerFixture
+):
+    """Test that empty chunks in the stream are ignored (coverage for continue)."""
+    model_name = 'gemini-2.5-flash'
+    model = GoogleModel(model_name, provider=google_provider)
+
+    # Chunk with NO content
+    empty_candidate = mocker.Mock(finish_reason=None, content=None)
+    empty_candidate.grounding_metadata = None
+    empty_candidate.url_context_metadata = None
+
+    chunk_empty = mocker.Mock(
+        candidates=[empty_candidate], model_version=model_name, usage_metadata=None, create_time=datetime.datetime.now()
+    )
+    chunk_empty.model_dump_json.return_value = '{}'
+
+    # Chunk WITH content (valid)
+    part_mock = mocker.Mock(
+        text='Hello',
+        thought=False,
+        function_call=None,
+        inline_data=None,
+        executable_code=None,
+        code_execution_result=None,
+    )
+    part_mock.thought_signature = None
+
+    valid_candidate = mocker.Mock(
+        finish_reason=GoogleFinishReason.STOP,
+        content=mocker.Mock(parts=[part_mock]),
+        grounding_metadata=None,
+        url_context_metadata=None,
+    )
+
+    chunk_valid = mocker.Mock(
+        candidates=[valid_candidate], model_version=model_name, usage_metadata=None, create_time=datetime.datetime.now()
+    )
+    chunk_valid.model_dump_json.return_value = '{"content": "Hello"}'
+
+    async def stream_iterator():
+        yield chunk_empty
+        yield chunk_valid
+
+    mocker.patch.object(model.client.aio.models, 'generate_content_stream', return_value=stream_iterator())
+
+    agent = Agent(model=model)
+
+    async with agent.run_stream('hello') as result:
+        output = await result.get_output()
+
+    assert output == 'Hello'
```

Collaborator (on `test_google_stream_empty_chunk`): Do we need this?
Collaborator: Instead of naming the model (which we don't do anywhere else), I'd prefer to include `model_response.provider_details['finish_reason']` if it exists, like we did in the Google exception.

Collaborator: Related to what I wrote in google.py about storing more of the actual error response context on `provider_details`, what do you think about including (JSONified) `model_response` on this exception, so that it can be accessed by the user (or in an observability platform)?
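Combining both suggestions, one hypothetical shape for that raise (the second argument is the `body` accepted by `UnexpectedModelBehavior`, which `ContentFilterError` inherits; `model_response_json` is an illustrative stand-in for however the response gets JSONified):

```python
from pydantic_ai.exceptions import ContentFilterError
from pydantic_ai.messages import ModelResponse


def raise_content_filter_error(model_response: ModelResponse, model_response_json: str) -> None:
    # Prefer the provider's raw finish reason over the model name.
    raw_reason = (model_response.provider_details or {}).get('finish_reason')
    message = 'Content filter triggered'
    if raw_reason:
        message += f': {raw_reason!r}'
    # Attaching the serialized response keeps the filtered response accessible
    # to users and observability platforms; how best to produce it (e.g. a
    # pydantic TypeAdapter over ModelResponse) is what the comment leaves open.
    raise ContentFilterError(message, model_response_json)
```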