Standardized openai init params (#21739)

## Patch Summary
community:openai[patch]: standardize init args

## Details
I made changes to the OpenAI Chat API wrapper test in the Langchain
open-source repository

- **File**: `libs/community/tests/unit_tests/chat_models/test_openai.py`
- **Changes**:
  - Updated `max_retries` with Pydantic Field
  - Updated the corresponding unit test
- **Related Issues**: #20085
  - Updated max_retries with Pydantic Field, updated the unit test.

---------

Co-authored-by: JuHyung Son <sonju0427@gmail.com>
pull/21773/head
Kyle Cassidy 2 weeks ago committed by GitHub
parent c03fd93fc1
commit eca8c4bcc6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,4 +1,5 @@
"""OpenAI chat wrapper."""
from __future__ import annotations
import logging
@ -217,7 +218,7 @@ class ChatOpenAI(BaseChatModel):
)
"""Timeout for requests to OpenAI completion API. Can be float, httpx.Timeout or
None."""
max_retries: int = 2
max_retries: int = Field(default=2)
"""Maximum number of retries to make when generating."""
streaming: bool = False
"""Whether to stream the results or not."""

@ -1,6 +1,7 @@
"""Test OpenAI Chat API wrapper."""
import json
from typing import Any
from typing import Any, List
from unittest.mock import MagicMock, patch
import pytest
@ -17,10 +18,19 @@ from langchain_community.chat_models.openai import ChatOpenAI
@pytest.mark.requires("openai")
def test_openai_model_param() -> None:
llm = ChatOpenAI(model="foo", openai_api_key="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"
llm = ChatOpenAI(model_name="foo", openai_api_key="foo") # type: ignore[call-arg]
assert llm.model_name == "foo"
test_cases: List[dict] = [
{"model_name": "foo", "openai_api_key": "foo"},
{"model": "foo", "openai_api_key": "foo"},
{"model_name": "foo", "api_key": "foo"},
{"model_name": "foo", "openai_api_key": "foo", "max_retries": 2},
]
for case in test_cases:
llm = ChatOpenAI(**case)
assert llm.model_name == "foo", "Model name should be 'foo'"
assert llm.openai_api_key == "foo", "API key should be 'foo'"
assert hasattr(llm, "max_retries"), "max_retries attribute should exist"
assert llm.max_retries == 2, "max_retries default should be set to 2"
def test_function_message_dict_to_function_message() -> None:

Loading…
Cancel
Save