fix: handle none connectionstr (#110)

pull/116/head
Laurel Orr 10 months ago committed by GitHub
parent d94101964f
commit a6a774bfb8
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -42,17 +42,19 @@ class AzureClient(OpenAIClient):
connection_str: connection string.
client_args: client arguments.
"""
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key, self.host = None, None
if connection_str:
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
if self.api_key is None:
raise ValueError(
@ -60,7 +62,6 @@ class AzureClient(OpenAIClient):
"variable or pass through `client_connection`."
)
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
self.host = self.host.rstrip("/")
if self.host is None:
raise ValueError(
"Azure Service URL not set "
@ -68,6 +69,7 @@ class AzureClient(OpenAIClient):
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.host = self.host.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAI_ENGINES:

@ -44,17 +44,19 @@ class AzureChatClient(OpenAIChatClient):
connection_str: connection string.
client_args: client arguments.
"""
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key, self.host = None, None
if connection_str:
connection_parts = connection_str.split("::")
if len(connection_parts) == 1:
self.api_key = connection_parts[0]
elif len(connection_parts) == 2:
self.api_key, self.host = connection_parts
else:
raise ValueError(
"Invalid connection string. "
"Must be either AZURE_OPENAI_KEY or "
"AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.api_key = self.api_key or os.environ.get("AZURE_OPENAI_KEY")
if self.api_key is None:
raise ValueError(
@ -62,7 +64,6 @@ class AzureChatClient(OpenAIChatClient):
"variable or pass through `client_connection`."
)
self.host = self.host or os.environ.get("AZURE_OPENAI_ENDPOINT")
self.host = self.host.rstrip("/")
if self.host is None:
raise ValueError(
"Azure Service URL not set "
@ -70,6 +71,7 @@ class AzureChatClient(OpenAIChatClient):
" Set AZURE_OPENAI_ENDPOINT or pass through `client_connection`."
" as AZURE_OPENAI_KEY::AZURE_OPENAI_ENDPOINT"
)
self.host = self.host.rstrip("/")
for key in self.PARAMS:
setattr(self, key, client_args.pop(key, self.PARAMS[key][1]))
if getattr(self, "engine") not in OPENAICHAT_ENGINES:

Loading…
Cancel
Save