From e3efe0cf2330dabd620e2e8749e03dbd66e4b0b0 Mon Sep 17 00:00:00 2001 From: Marcel Canu Date: Fri, 7 Mar 2025 14:29:54 -0300 Subject: [PATCH] Validating connection in serializer for both import and export. Validating bucket name in serializer --- label_studio/io_storages/s3/serializers.py | 39 +++++++++++++--------- 1 file changed, 23 insertions(+), 16 deletions(-) diff --git a/label_studio/io_storages/s3/serializers.py b/label_studio/io_storages/s3/serializers.py index 3b46cd0909f8..f97629e7cd1a 100644 --- a/label_studio/io_storages/s3/serializers.py +++ b/label_studio/io_storages/s3/serializers.py @@ -3,29 +3,33 @@ import os from botocore.exceptions import ClientError, ParamValidationError +from botocore.handlers import validate_bucket_name from io_storages.s3.models import S3ExportStorage, S3ImportStorage from io_storages.serializers import ExportStorageSerializer, ImportStorageSerializer from rest_framework import serializers from rest_framework.exceptions import ValidationError -class S3ImportStorageSerializer(ImportStorageSerializer): - type = serializers.ReadOnlyField(default=os.path.basename(os.path.dirname(__file__))) - presign = serializers.BooleanField(required=False, default=True) +class S3StorageSerializerMixin: secure_fields = ['aws_access_key_id', 'aws_secret_access_key'] - class Meta: - model = S3ImportStorage - fields = '__all__' - def to_representation(self, instance): result = super().to_representation(instance) - for attr in S3ImportStorageSerializer.secure_fields: + for attr in self.secure_fields: result.pop(attr) return result + def validate_bucket(self, value): + if not value: + return value + try: + validate_bucket_name({'Bucket': value}) + except ParamValidationError as exc: + raise ValidationError(exc.kwargs['report']) from exc + return value + def validate(self, data): - data = super(S3ImportStorageSerializer, self).validate(data) + data = super().validate(data) if not data.get('bucket', None): return data @@ -36,7 +40,7 @@ def validate(self, data): else: if 'id' in self.initial_data: storage_object = self.Meta.model.objects.get(id=self.initial_data['id']) - for attr in S3ImportStorageSerializer.secure_fields: + for attr in self.secure_fields: data[attr] = data.get(attr) or getattr(storage_object, attr) storage = self.Meta.model(**data) try: @@ -63,14 +67,17 @@ def validate(self, data): return data -class S3ExportStorageSerializer(ExportStorageSerializer): +class S3ImportStorageSerializer(S3StorageSerializerMixin, ImportStorageSerializer): type = serializers.ReadOnlyField(default=os.path.basename(os.path.dirname(__file__))) + presign = serializers.BooleanField(required=False, default=True) - def to_representation(self, instance): - result = super().to_representation(instance) - result.pop('aws_access_key_id') - result.pop('aws_secret_access_key') - return result + class Meta: + model = S3ImportStorage + fields = '__all__' + + +class S3ExportStorageSerializer(S3StorageSerializerMixin, ExportStorageSerializer): + type = serializers.ReadOnlyField(default=os.path.basename(os.path.dirname(__file__))) class Meta: model = S3ExportStorage