Skip to content

Commit

Permalink
bug: Upgrade Langchain, AWS Powertools, Pydantic. Fix config without …
Browse files Browse the repository at this point in the history
…default embedding and step function deployment. (#598)
  • Loading branch information
charles-marion authored Oct 31, 2024
1 parent 9607ef8 commit bd520f6
Show file tree
Hide file tree
Showing 26 changed files with 145 additions and 79 deletions.
6 changes: 6 additions & 0 deletions lib/chatbot-api/functions/api-handler/common/constant.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,10 @@
SAFE_STR_REGEX = r"^[A-Za-z0-9-_. ]*$"
SAFE_HTTP_STR_REGEX = r"^[A-Za-z0-9-_.:/]*$"
ID_FIELD_VALIDATION = Field(min_length=1, max_length=100, pattern=SAFE_STR_REGEX)
ID_FIELD_VALIDATION_OPTIONAL = Field(
min_length=1, max_length=100, pattern=SAFE_STR_REGEX, default=None
)
SAFE_SHORT_STR_VALIDATION = Field(min_length=1, max_length=100, pattern=SAFE_STR_REGEX)
SAFE_SHORT_STR_VALIDATION_OPTIONAL = Field(
min_length=1, max_length=100, pattern=SAFE_STR_REGEX, default=None
)
5 changes: 3 additions & 2 deletions lib/chatbot-api/functions/api-handler/index.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,8 +52,9 @@ def handler(event: dict, context: LambdaContext) -> dict:
)
return app.resolve(event, context)
except ValidationError as e:
logger.warning(e.errors())
raise e
errors = e.errors(include_url=False, include_context=False, include_input=False)
logger.warning("Validation error", errors=errors)
raise ValueError(f"Invalid request. Details: {errors}")
except CommonError as e:
logger.warning(str(e))
raise e
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@

class CrossEncodersRequest(BaseModel):
provider: str = Field(min_length=1, max_length=500, pattern=SAFE_STR_REGEX)
model: str = Field(min_length=1, max_length=500, pattern=SAFE_STR_REGEX)
model: str = Field(min_length=1, max_length=500, pattern=r"^[A-Za-z0-9-_. /]*$")
reference: str = Field(min_length=1, max_length=MAX_STR_INPUT_LENGTH)
passages: List[Annotated[str, Field(min_length=1, max_length=MAX_STR_INPUT_LENGTH)]]

Expand Down
15 changes: 9 additions & 6 deletions lib/chatbot-api/functions/api-handler/routes/documents.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import os
from common.constant import (
ID_FIELD_VALIDATION,
ID_FIELD_VALIDATION_OPTIONAL,
SAFE_HTTP_STR_REGEX,
SAFE_STR_REGEX,
MAX_STR_INPUT_LENGTH,
Expand All @@ -19,9 +20,13 @@
router = Router()
logger = Logger()

CONTENT_TYPE_VALDIATION = Field(
min_length=1, max_length=100, pattern=r"^[A-Za-z0-9-_./]*$"
)


class FileUploadRequest(BaseModel):
workspaceId: Optional[str] = ID_FIELD_VALIDATION
workspaceId: Optional[str] = ID_FIELD_VALIDATION_OPTIONAL
fileName: str = Field(min_length=1, max_length=500, pattern=SAFE_STR_REGEX)


Expand All @@ -43,7 +48,7 @@ class WebsiteDocumentRequest(BaseModel):
address: str = Field(min_length=1, max_length=500, pattern=SAFE_HTTP_STR_REGEX)
followLinks: bool
limit: int = Field(gt=-1)
contentTypes: Optional[List[Annotated[str, SAFE_SHORT_STR_VALIDATION]]]
contentTypes: Optional[List[Annotated[str, CONTENT_TYPE_VALDIATION]]] = None


class RssFeedDocumentRequest(BaseModel):
Expand All @@ -59,16 +64,14 @@ class RssFeedDocumentRequest(BaseModel):
default=None, min_length=1, max_length=100, pattern=SAFE_STR_REGEX
)
followLinks: bool
contentTypes: Optional[List[Annotated[str, SAFE_SHORT_STR_VALIDATION]]]
contentTypes: Optional[List[Annotated[str, CONTENT_TYPE_VALDIATION]]] = None


class RssFeedCrawlerUpdateRequest(BaseModel):
documentType: str = SAFE_SHORT_STR_VALIDATION
followLinks: bool
limit: int = Field(lt=500)
contentTypes: Optional[Annotated[str, SAFE_SHORT_STR_VALIDATION]] = Field(
min_length=1, max_length=100, pattern=SAFE_STR_REGEX
)
contentTypes: Optional[List[Annotated[str, CONTENT_TYPE_VALDIATION]]] = None


class ListDocumentsRequest(BaseModel):
Expand Down
2 changes: 1 addition & 1 deletion lib/chatbot-api/functions/api-handler/routes/embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@

class EmbeddingsRequest(BaseModel):
provider: str = Field(min_length=1, max_length=500, pattern=SAFE_STR_REGEX)
model: str = Field(min_length=1, max_length=500, pattern=SAFE_STR_REGEX)
model: str = Field(min_length=1, max_length=500, pattern=r"^[A-Za-z0-9-_. /]*$")
passages: List[Annotated[str, Field(min_length=1, max_length=MAX_STR_INPUT_LENGTH)]]
task: Optional[Task] = Task.STORE

Expand Down
11 changes: 8 additions & 3 deletions lib/chatbot-api/functions/api-handler/routes/workspaces.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Annotated, List, Optional
from common.constant import (
SAFE_SHORT_STR_VALIDATION,
SAFE_SHORT_STR_VALIDATION_OPTIONAL,
)
from common.validation import WorkspaceIdValidation
import genai_core.types
Expand All @@ -27,9 +28,13 @@ class CreateWorkspaceAuroraRequest(BaseModel):
kind: str = SAFE_SHORT_STR_VALIDATION
name: str = Field(min_length=1, max_length=100, pattern=name_regex)
embeddingsModelProvider: str = SAFE_SHORT_STR_VALIDATION
embeddingsModelName: str = SAFE_SHORT_STR_VALIDATION
crossEncoderModelProvider: Optional[str] = SAFE_SHORT_STR_VALIDATION
crossEncoderModelName: Optional[str] = SAFE_SHORT_STR_VALIDATION
embeddingsModelName: str = Field(
min_length=0, max_length=500, pattern=r"^[A-Za-z0-9-_. /]*$", default=None
)
crossEncoderModelProvider: Optional[str] = SAFE_SHORT_STR_VALIDATION_OPTIONAL
crossEncoderModelName: Optional[str] = Field(
min_length=0, max_length=500, pattern=r"^[A-Za-z0-9-_. /]*$", default=None
)
languages: List[Annotated[str, SAFE_SHORT_STR_VALIDATION]]
metric: str = SAFE_SHORT_STR_VALIDATION
index: bool
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,8 @@
from aws_lambda_powertools import Logger, Tracer
from aws_lambda_powertools.utilities.typing import LambdaContext
from aws_lambda_powertools.logging import correlation_paths
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, ValidationError


tracer = Tracer()
logger = Logger(log_uncaught_exceptions=True)
Expand All @@ -17,29 +18,36 @@
MAX_STR_INPUT_LENGTH = 1000000
SAFE_STR_REGEX = r"^[A-Za-z0-9-_. ]*$"
SAFE_SHORT_STR_VALIDATION = Field(min_length=0, max_length=500, pattern=SAFE_STR_REGEX)
SAFE_SHORT_STR_VALIDATION_OPTIONAL = Field(
min_length=0, max_length=500, pattern=SAFE_STR_REGEX, default=None
)


class ModelKwargsFieldValidation(BaseModel):
streaming: Optional[bool]
maxTokens: Optional[int] = Field(gt=0, lt=1000000)
temperature: Optional[float] = Field(ge=0, le=1)
topP: Optional[float] = Field(ge=0, le=1)
streaming: Optional[bool] = None
maxTokens: Optional[int] = Field(gt=0, lt=1000000, default=None)
temperature: Optional[float] = Field(ge=0, le=1, default=None)
topP: Optional[float] = Field(ge=0, le=1, default=None)


class FileFieldValidation(BaseModel):
provider: Optional[str] = SAFE_SHORT_STR_VALIDATION
key: Optional[str] = SAFE_SHORT_STR_VALIDATION
provider: Optional[str] = SAFE_SHORT_STR_VALIDATION_OPTIONAL
key: Optional[str] = SAFE_SHORT_STR_VALIDATION_OPTIONAL


class DataFieldValidation(BaseModel):
modelName: Optional[str] = SAFE_SHORT_STR_VALIDATION
provider: Optional[str] = SAFE_SHORT_STR_VALIDATION
sessionId: Optional[str] = SAFE_SHORT_STR_VALIDATION
workspaceId: Optional[str] = SAFE_SHORT_STR_VALIDATION
mode: Optional[str] = SAFE_SHORT_STR_VALIDATION
text: Optional[str] = Field(min_length=1, max_length=MAX_STR_INPUT_LENGTH)
modelName: Optional[str] = Field(
min_length=0, max_length=500, pattern=r"^[A-Za-z0-9-_. /:]*$", default=None
)
provider: Optional[str] = SAFE_SHORT_STR_VALIDATION_OPTIONAL
sessionId: Optional[str] = SAFE_SHORT_STR_VALIDATION_OPTIONAL
workspaceId: Optional[str] = SAFE_SHORT_STR_VALIDATION_OPTIONAL
mode: Optional[str] = SAFE_SHORT_STR_VALIDATION_OPTIONAL
text: Optional[str] = Field(
min_length=1, max_length=MAX_STR_INPUT_LENGTH, default=None
)
files: Optional[List[FileFieldValidation]]
modelKwargs: Optional[ModelKwargsFieldValidation]
modelKwargs: Optional[ModelKwargsFieldValidation] = None


class InputValidation(BaseModel):
Expand Down Expand Up @@ -67,11 +75,15 @@ def handler(event, context: LambdaContext):
"userId": event["identity"]["sub"],
"data": request.get("data", {}),
}
InputValidation(**message)

try:
InputValidation(**message)
response = sns.publish(TopicArn=TOPIC_ARN, Message=json.dumps(message))
return response
except ValidationError as e:
errors = e.errors(include_url=False, include_context=False, include_input=False)
logger.warning("Validation error", errors=errors)
raise ValueError(f"Invalid request. Details: {errors}")
except Exception as e:
# Do not return an unknown exception to the end user.
logger.exception(e)
Expand Down
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
pydantic==2.4.0
pydantic==2.9.2
aws_xray_sdk==2.14.0
4 changes: 4 additions & 0 deletions lib/rag-engines/aurora-pgvector/create-aurora-workspace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,10 @@ export class CreateAuroraWorkspace extends Construct {
? cdk.RemovalPolicy.RETAIN_ON_UPDATE_OR_DELETE
: cdk.RemovalPolicy.DESTROY,
retention: props.config.logRetention,
// Log group name should start with `/aws/vendedlogs/` to not exceed Cloudwatch Logs Resource Policy
// size limit.
// https://docs.aws.amazon.com/step-functions/latest/dg/bp-cwl.html
logGroupName: `/aws/vendedlogs/states/CreateAuroraWorkspace-${this.node.addr}`,
}
);

Expand Down
4 changes: 4 additions & 0 deletions lib/rag-engines/data-import/file-import-workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,10 @@ export class FileImportWorkflow extends Construct {
? cdk.RemovalPolicy.RETAIN_ON_UPDATE_OR_DELETE
: cdk.RemovalPolicy.DESTROY,
retention: props.config.logRetention,
// Log group name should start with `/aws/vendedlogs/` to not exceed Cloudwatch Logs Resource Policy
// size limit.
// https://docs.aws.amazon.com/step-functions/latest/dg/bp-cwl.html
logGroupName: `/aws/vendedlogs/states/FileImportStateMachine-${this.node.addr}`,
});

const workflow = setProcessing.next(fileImportJob).next(setProcessed);
Expand Down
4 changes: 4 additions & 0 deletions lib/rag-engines/data-import/website-crawling-workflow.ts
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,10 @@ export class WebsiteCrawlingWorkflow extends Construct {
? cdk.RemovalPolicy.RETAIN_ON_UPDATE_OR_DELETE
: cdk.RemovalPolicy.DESTROY,
retention: props.config.logRetention,
// Log group name should start with `/aws/vendedlogs/` to not exceed Cloudwatch Logs Resource Policy
// size limit.
// https://docs.aws.amazon.com/step-functions/latest/dg/bp-cwl.html
logGroupName: `/aws/vendedlogs/states/WebsiteCrawling-${this.node.addr}`,
});

const workflow = setProcessing.next(webCrawlerJob).next(setProcessed);
Expand Down
4 changes: 4 additions & 0 deletions lib/rag-engines/kendra-retrieval/create-kendra-workspace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,10 @@ export class CreateKendraWorkspace extends Construct {
? cdk.RemovalPolicy.RETAIN_ON_UPDATE_OR_DELETE
: cdk.RemovalPolicy.DESTROY,
retention: props.config.logRetention,
// Log group name should start with `/aws/vendedlogs/` to not exceed Cloudwatch Logs Resource Policy
// size limit.
// https://docs.aws.amazon.com/step-functions/latest/dg/bp-cwl.html
logGroupName: `/aws/vendedlogs/states/CreateKendraWorkspace-${this.node.addr}`,
}
);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -151,6 +151,10 @@ export class CreateOpenSearchWorkspace extends Construct {
? cdk.RemovalPolicy.RETAIN_ON_UPDATE_OR_DELETE
: cdk.RemovalPolicy.DESTROY,
retention: props.config.logRetention,
// Log group name should start with `/aws/vendedlogs/` to not exceed Cloudwatch Logs Resource Policy
// size limit.
// https://docs.aws.amazon.com/step-functions/latest/dg/bp-cwl.html
logGroupName: `/aws/vendedlogs/states/CreateOpenSearchWorkspace-${this.node.addr}`,
}
);

Expand Down
4 changes: 4 additions & 0 deletions lib/rag-engines/workspaces/delete-document.ts
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ export class DeleteDocument extends Construct {
? cdk.RemovalPolicy.RETAIN_ON_UPDATE_OR_DELETE
: cdk.RemovalPolicy.DESTROY,
retention: props.config.logRetention,
// Log group name should start with `/aws/vendedlogs/` to not exceed Cloudwatch Logs Resource Policy
// size limit.
// https://docs.aws.amazon.com/step-functions/latest/dg/bp-cwl.html
logGroupName: `/aws/vendedlogs/states/DeleteWorkspace-${this.node.addr}`,
});

const stateMachine = new sfn.StateMachine(this, "DeleteDocument", {
Expand Down
4 changes: 4 additions & 0 deletions lib/rag-engines/workspaces/delete-workspace.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,10 @@ export class DeleteWorkspace extends Construct {
? cdk.RemovalPolicy.RETAIN_ON_UPDATE_OR_DELETE
: cdk.RemovalPolicy.DESTROY,
retention: props.config.logRetention,
// Log group name should start with `/aws/vendedlogs/` to not exceed Cloudwatch Logs Resource Policy
// size limit.
// https://docs.aws.amazon.com/step-functions/latest/dg/bp-cwl.html
logGroupName: `/aws/vendedlogs/states/DeleteWorkspace-${this.node.addr}`,
});

const stateMachine = new sfn.StateMachine(this, "DeleteWorkspace", {
Expand Down
5 changes: 2 additions & 3 deletions lib/shared/file-import-batch-job/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,12 +4,11 @@ numpy==1.26.0
cfnresponse==1.1.2
aws_requests_auth==0.4.3
requests-aws4auth==1.2.3
langchain==0.2.14
langchain-community==0.2.12
langchain==0.3.5
langchain-community==0.3.3
opensearch-py==2.3.1
psycopg2-binary==2.9.7
pgvector==0.2.2
pydantic==2.4.0
urllib3<2
openai==1.47.0
beautifulsoup4==4.12.2
Expand Down
8 changes: 5 additions & 3 deletions lib/shared/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ export class Shared extends Construct {
this.kmsKeyAlias = props.config.prefix + "genaichatbot-shared-key";
this.queueKmsKeyAlias =
props.config.prefix + "genaichatbot-queue-shared-key";
const powerToolsLayerVersion = "46";
const powerToolsLayerVersion = "2";

this.defaultEnvironmentVariables = {
POWERTOOLS_DEV: "false",
Expand Down Expand Up @@ -258,10 +258,12 @@ export class Shared extends Construct {
stringValue: JSON.stringify(props.config),
});

//https://docs.powertools.aws.dev/lambda/python/3.2.0/
const pythonVersion = pythonRuntime.name.replace(".", "");
const powerToolsArn =
lambdaArchitecture === lambda.Architecture.X86_64
? `arn:${cdk.Aws.PARTITION}:lambda:${cdk.Aws.REGION}:017000801446:layer:AWSLambdaPowertoolsPythonV2:${powerToolsLayerVersion}`
: `arn:${cdk.Aws.PARTITION}:lambda:${cdk.Aws.REGION}:017000801446:layer:AWSLambdaPowertoolsPythonV2-Arm64:${powerToolsLayerVersion}`;
? `arn:${cdk.Aws.PARTITION}:lambda:${cdk.Aws.REGION}:017000801446:layer:AWSLambdaPowertoolsPythonV3-${pythonVersion}-x86_64:${powerToolsLayerVersion}`
: `arn:${cdk.Aws.PARTITION}:lambda:${cdk.Aws.REGION}:017000801446:layer:AWSLambdaPowertoolsPythonV3-${pythonVersion}-arm64:${powerToolsLayerVersion}`;

const powerToolsLayer = lambda.LayerVersion.fromLayerVersionArn(
this,
Expand Down
12 changes: 6 additions & 6 deletions lib/shared/layers/common/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,15 +3,15 @@ numpy==1.26.0
cfnresponse==1.1.2
aws_requests_auth==0.4.3
requests-aws4auth==1.2.3
langchain==0.2.14
langchain-community==0.2.12
langchain-aws==0.1.17
langchain-openai==0.1.25
openai==1.47.0
langchain==0.3.5
langchain-core==0.3.13
langchain-community==0.3.3
langchain-aws==0.2.4
langchain-openai==0.2.4
langchain-text-splitters==0.3.1
opensearch-py==2.4.2
psycopg2-binary==2.9.7
pgvector==0.2.2
pydantic==2.4.0
urllib3<2
beautifulsoup4==4.12.2
requests==2.32.0
Expand Down
2 changes: 1 addition & 1 deletion lib/shared/layers/python-sdk/python/genai_core/chunks.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import genai_core.opensearch.chunks
from genai_core.types import CommonError, Task
from typing import List, Optional
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_text_splitters import RecursiveCharacterTextSplitter

PROCESSING_BUCKET_NAME = os.environ.get("PROCESSING_BUCKET_NAME", "")
s3 = boto3.resource("s3")
Expand Down
5 changes: 3 additions & 2 deletions lib/shared/utils.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,14 @@ export abstract class Utils {
}
}

static getDefaultEmbeddingsModel(config: SystemConfig): string {
static getDefaultEmbeddingsModel(config: SystemConfig): string | undefined {
const defaultModel = config.rag.embeddingsModels.find(
(model) => model.default === true
);

if (!defaultModel) {
throw new Error("No default embeddings model found");
// No default embdeding is set in the config when Aurora or Opensearch are not used.
return undefined;
}

return `${defaultModel.provider}::${defaultModel.dimensions}::${defaultModel.name}`;
Expand Down
3 changes: 1 addition & 2 deletions lib/shared/web-crawler-batch-job/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,10 @@ numpy==1.26.0
cfnresponse==1.1.2
aws_requests_auth==0.4.3
requests-aws4auth==1.2.3
langchain==0.2.14
langchain==0.3.5
opensearch-py==2.3.1
psycopg2-binary==2.9.7
pgvector==0.2.2
pydantic==2.4.0
urllib3<2
openai==0.28.0
beautifulsoup4==4.12.2
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -72,10 +72,6 @@ export default function Embeddings() {
validate: (form) => {
const errors: Record<string, string | string[]> = {};

if (!form.embeddingsModel) {
errors.embeddingsModel = "Embeddings model is required";
}

for (let i = 0; i < form.input.length; i++) {
const input = form.input[i];
if (input.trim().length === 0) {
Expand Down
2 changes: 1 addition & 1 deletion lib/user-interface/react-app/vite.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ export default defineConfig({
define: {
"process.env": {},
// Prevents replacing global in the import strings.
global: "global",
global: isDev ? {} : "global",
},
plugins: [
isDev && {
Expand Down
Loading

0 comments on commit bd520f6

Please sign in to comment.