diff --git a/src/marvin/openai/ChatCompletion/__init__.py b/src/marvin/openai/ChatCompletion/__init__.py
index 38bbb3d22..339eef0b2 100644
--- a/src/marvin/openai/ChatCompletion/__init__.py
+++ b/src/marvin/openai/ChatCompletion/__init__.py
@@ -7,9 +7,10 @@
 import warnings
 from typing import Type, Optional, Union, Literal
+from pydantic import Extra
 
 
-class ChatCompletionConfig(BaseSettings):
+class ChatCompletionConfig(BaseSettings, extra=Extra.allow):
     model: str = "gpt-3.5-turbo"
     temperature: float = 0
     functions: list = Field(default_factory=list)
@@ -30,7 +31,7 @@ def merge(self, *args, **kwargs):
                 setattr(self, key, getattr(self, key, []) + value)
             else:
                 setattr(self, key, value)
-        return {k: v for k, v in self.__dict__.items() if v != []}
+        return {k: v for k, v in self.__dict__.items() if v}
 
 
 def process_list(lst):
@@ -68,18 +69,19 @@ def create(cls, *args, response_model: Optional[Type[BaseModel]] = None, **kwarg
         }
         payload = config.merge(**kwargs)
         response = cls.observer(super(ChatCompletion, cls).create)(*args, **payload)
-        response.to_model = lambda: (
-            process_list(
-                list(
-                    map(
-                        lambda x: response_model.parse_raw(
-                            x.message.function_call.arguments
-                        ),
-                        response.choices,
+        if response_model is not None:
+            response.to_model = lambda: (
+                process_list(
+                    list(
+                        map(
+                            lambda x: response_model.parse_raw(
+                                x.message.function_call.arguments
+                            ),
+                            response.choices,
+                        )
                     )
                 )
             )
-        )
         return response
 
     @classmethod
@@ -105,18 +107,19 @@ async def acreate(
         response = await cls.observer(super(ChatCompletion, cls).acreate)(
             *args, **payload
         )
-        response.to_model = lambda: (
-            process_list(
-                list(
-                    map(
-                        lambda x: response_model.parse_raw(
-                            x.message.function_call.arguments
-                        ),
-                        response.choices,
+        if response_model is not None:
+            response.to_model = lambda: (
+                process_list(
+                    list(
+                        map(
+                            lambda x: response_model.parse_raw(
+                                x.message.function_call.arguments
+                            ),
+                            response.choices,
+                        )
                     )
                 )
             )
-        )
         return response
 
     @staticmethod