45 changes: 29 additions & 16 deletions sagemaker-train/src/sagemaker/ai_registry/dataset.py
@@ -24,6 +24,7 @@

import pandas as pd

+from sagemaker.ai_registry.dataset_format_detector import DatasetFormatDetector
from sagemaker.ai_registry.air_hub import AIRHub
from sagemaker.ai_registry.air_utils import _determine_new_version, _get_default_bucket
from sagemaker.ai_registry.air_constants import (
@@ -179,6 +180,21 @@ def _validate_dataset_file(cls, file_path: str) -> None:
        max_size_mb = DATASET_MAX_FILE_SIZE_BYTES / (1024 * 1024)
        raise ValueError(f"File size {file_size_mb:.2f} MB exceeds maximum allowed size of {max_size_mb:.0f} MB")

+    @classmethod
+    def _validate_dataset_format(cls, file_path: str) -> None:
+        """Validate dataset format using DatasetFormatDetector.
+
+        Args:
+            file_path: Path to the dataset file (local path)
+
+        Raises:
+            ValueError: If the dataset format cannot be detected
+        """
+        detector = DatasetFormatDetector()
+        is_valid = detector.validate_dataset(file_path)
+        if not is_valid:
+            raise ValueError(f"Unable to detect format for {file_path}. Please provide a valid dataset file.")
+
    @classmethod
    @_telemetry_emitter(feature=Feature.MODEL_CUSTOMIZATION, func_name="DataSet.get")
    def get(cls, name: str, sagemaker_session=None) -> "DataSet":
@@ -257,28 +273,25 @@ def create(
            s3_prefix = s3_key  # Use full path including filename
            method = DataSetMethod.GENERATED

-            # Download and validate if customization technique is provided
-            if customization_technique:
-                with tempfile.NamedTemporaryFile(
-                    delete=False, suffix=os.path.splitext(s3_key)[1]
-                ) as tmp_file:
-                    local_path = tmp_file.name
-
-                try:
-                    AIRHub.download_from_s3(source, local_path)
-                    validate_dataset(local_path, customization_technique.value)
-                finally:
-                    if os.path.exists(local_path):
-                        os.remove(local_path)
+            # Download and validate format
+            with tempfile.NamedTemporaryFile(
+                delete=False, suffix=os.path.splitext(s3_key)[1]
+            ) as tmp_file:
+                local_path = tmp_file.name
+
+            try:
+                AIRHub.download_from_s3(source, local_path)
+                cls._validate_dataset_format(local_path)
+            finally:
+                if os.path.exists(local_path):
+                    os.remove(local_path)
        else:
            # Local file - upload to S3
            bucket_name = _get_default_bucket()
            s3_prefix = _get_default_s3_prefix(name)
            method = DataSetMethod.UPLOADED

-            if customization_technique:
-                validate_dataset(source, customization_technique.value)
-
+            cls._validate_dataset_format(source)
            AIRHub.upload_to_s3(bucket_name, s3_prefix, source)

        # Create hub content document
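For orientation (not part of the diff): a minimal sketch of how the changed path gets exercised. The full `DataSet.create` signature isn't visible above, so the keyword arguments here are assumptions inferred from the hunk (`name` feeds `_get_default_s3_prefix`, `source` is a local path or s3:// URI, and after this change the format check no longer depends on `customization_technique`):

from sagemaker.ai_registry.dataset import DataSet

# Local file: create() uploads to the default bucket. After this change the
# format check runs unconditionally, instead of only when a
# customization_technique was supplied.
dataset = DataSet.create(
    name="my-dataset",       # assumed kwarg; used to derive the default S3 prefix
    source="./train.jsonl",  # assumed kwarg; an s3:// URI takes the download branch
)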
97 changes: 97 additions & 0 deletions sagemaker-train/src/sagemaker/ai_registry/dataset_format_detector.py

@@ -0,0 +1,97 @@
# Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import json
from typing import Any, Dict
from pathlib import Path


class DatasetFormatDetector:
    """Utility class for detecting dataset formats."""

    # Schema directory
    SCHEMA_DIR = Path(__file__).parent / "schemas"

    @staticmethod
    def _load_schema(format_name: str) -> Dict[str, Any]:
        """Load JSON schema for a format."""
        schema_path = DatasetFormatDetector.SCHEMA_DIR / f"{format_name}.json"
        if schema_path.exists():
            with open(schema_path) as f:
                return json.load(f)
        return {}

    @staticmethod
    def validate_dataset(file_path: str) -> bool:
        """
        Validate if the dataset adheres to any known format.

        Args:
            file_path: Path to the JSONL file

        Returns:
            True if the dataset is valid according to any known format, False otherwise
        """
        import jsonschema  # lazy import: only needed when validation runs

        # Formats detected via bundled JSON schemas
        schema_formats = [
            "dpo", "converse", "hf_preference", "hf_prompt_completion",
            "verl", "openai_chat", "genqa",
        ]

        try:
            with open(file_path, "r") as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        continue
                    data = json.loads(line)

                    # Try schema validation first
                    for format_name in schema_formats:
                        schema = DatasetFormatDetector._load_schema(format_name)
                        if schema:
                            try:
                                jsonschema.validate(instance=data, schema=schema)
                                return True
                            except jsonschema.exceptions.ValidationError:
                                continue

                    # Fall back to the RFT-style format (messages plus additional fields)
                    if DatasetFormatDetector._is_rft_format(data):
                        return True

                    # Only the first non-empty record is inspected
                    break
            return False
        except (json.JSONDecodeError, FileNotFoundError, IOError):
            return False

    @staticmethod
    def _is_rft_format(data: Dict[str, Any]) -> bool:
        """Check if data matches RFT format pattern."""
        if not isinstance(data, dict) or "messages" not in data:
            return False

        messages = data["messages"]
        if not isinstance(messages, list) or not messages:
            return False

        # Check message structure
        for msg in messages:
            if not isinstance(msg, dict):
                return False
            if "role" not in msg or "content" not in msg:
                return False
            if not isinstance(msg["role"], str) or not isinstance(msg["content"], str):
                return False

        return True
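As a usage sketch (again not part of the diff): `validate_dataset` inspects only the first non-empty line, so a single RFT-style record is enough for a file to pass. The `reward_spec` key below is a made-up extra field; `_is_rft_format` only requires the `messages` structure and ignores other keys. No schema-based match is shown because the `schemas/*.json` files are not included in this PR:

import json
import tempfile

from sagemaker.ai_registry.dataset_format_detector import DatasetFormatDetector

# One record matching the RFT fallback: a "messages" list whose entries
# each carry string "role" and "content" fields.
record = {
    "messages": [
        {"role": "user", "content": "What is 2 + 2?"},
        {"role": "assistant", "content": "4"},
    ],
    "reward_spec": {"ground_truth": "4"},  # hypothetical extra field; ignored by the check
}

with tempfile.NamedTemporaryFile("w", suffix=".jsonl", delete=False) as f:
    f.write(json.dumps(record) + "\n")
    path = f.name

assert DatasetFormatDetector.validate_dataset(path) is True

Note that a successful schema match short-circuits before the RFT fallback, so with the bundled schemas installed the same file may instead be accepted by, say, the openai_chat schema.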