Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
211 changes: 209 additions & 2 deletions crates/fluss/src/metadata/datatype.rs
Original file line number Diff line number Diff line change
Expand Up @@ -933,8 +933,12 @@ impl RowType {
pub fn project_with_field_names(&self, field_names: &[String]) -> Result<RowType> {
let indices: Vec<usize> = field_names
.iter()
.filter_map(|pk| self.get_field_index(pk))
.collect();
.map(|name| {
self.get_field_index(name).ok_or_else(|| IllegalArgument {
message: format!("Field '{}' does not exist in the row type", name),
})
})
.collect::<Result<Vec<_>>>()?;

self.project(indices.as_slice())
}
Expand Down Expand Up @@ -1405,6 +1409,10 @@ fn test_deeply_nested_types() {
assert_eq!(nested.to_string(), "ARRAY<MAP<STRING, ROW<x INT, y INT>>>");
}

// ============================================================================
// DecimalType validation tests
// ============================================================================

#[test]
fn test_decimal_invalid_precision() {
// DecimalType::with_nullable should return an error for invalid precision
Expand All @@ -1431,6 +1439,76 @@ fn test_decimal_invalid_scale() {
);
}

// ============================================================================
// DecimalType validation tests - edge cases
// ============================================================================

#[test]
fn test_decimal_valid_precision_and_scale() {
    // A nullable decimal with precision 10 and scale 2 is accepted.
    let typical = DecimalType::with_nullable(true, 10, 2);
    assert!(typical.is_ok());
    let typical = typical.unwrap();
    assert_eq!(typical.precision(), 10);
    assert_eq!(typical.scale(), 2);
    // The nullable variant must not render "NOT NULL".
    let rendered = typical.to_string();
    assert!(!rendered.contains("NOT NULL"));

    // Precision 38 with scale 0 is accepted.
    let widest = DecimalType::with_nullable(true, 38, 0);
    assert!(widest.is_ok());
    let widest = widest.unwrap();
    assert_eq!(widest.precision(), 38);
    assert_eq!(widest.scale(), 0);

    // Precision 1 with scale 0 is accepted, here as a non-nullable type.
    let narrowest = DecimalType::with_nullable(false, 1, 0);
    assert!(narrowest.is_ok());
    let narrowest = narrowest.unwrap();
    assert_eq!(narrowest.precision(), 1);
    assert_eq!(narrowest.scale(), 0);
    // The non-nullable variant must render "NOT NULL".
    let rendered = narrowest.to_string();
    assert!(rendered.contains("NOT NULL"));
}

#[test]
fn test_decimal_invalid_precision_zero() {
    // A precision of zero must be rejected.
    let outcome = DecimalType::with_nullable(true, 0, 0);
    assert!(outcome.is_err());

    // The error text must state the allowed precision range.
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("Decimal precision must be between 1 and 38"));
}

#[test]
fn test_decimal_scale_equals_precision_boundary() {
    // A scale equal to the precision (10, 10) is still accepted.
    let outcome = DecimalType::with_nullable(true, 10, 10);
    assert!(outcome.is_ok());

    let boundary = outcome.unwrap();
    assert_eq!(boundary.precision(), 10);
    assert_eq!(boundary.scale(), 10);
}

// ============================================================================
// TimeType validation tests
// ============================================================================

#[test]
fn test_time_valid_precision() {
    // Every precision from 0 through 9 must be accepted and reported back unchanged.
    for precision in 0..=9 {
        let created = TimeType::with_nullable(true, precision);
        assert!(created.is_ok(), "precision {} should be valid", precision);
        assert_eq!(created.unwrap().precision(), precision);
    }
}

#[test]
fn test_time_invalid_precision() {
// TimeType::with_nullable should return an error for invalid precision
Expand All @@ -1444,6 +1522,21 @@ fn test_time_invalid_precision() {
);
}

// ============================================================================
// TimestampType validation tests
// ============================================================================

#[test]
fn test_timestamp_valid_precision() {
    // Every precision from 0 through 9 must be accepted and reported back unchanged.
    for precision in 0..=9 {
        let created = TimestampType::with_nullable(true, precision);
        assert!(created.is_ok(), "precision {} should be valid", precision);
        assert_eq!(created.unwrap().precision(), precision);
    }
}

#[test]
fn test_timestamp_invalid_precision() {
// TimestampType::with_nullable should return an error for invalid precision
Expand All @@ -1469,3 +1562,117 @@ fn test_timestamp_ltz_invalid_precision() {
.contains("Timestamp with local time zone precision must be between 0 and 9")
);
}

// ============================================================================
// RowType projection tests
// ============================================================================

#[test]
fn test_row_type_project_valid_indices() {
    // Three-column source schema: id, name, age.
    let source = RowType::with_data_types_and_field_names(
        vec![DataTypes::int(), DataTypes::string(), DataTypes::bigint()],
        vec!["id", "name", "age"],
    );

    // Selecting positions 0 and 2 keeps "id" and "age", in that order.
    let projected = source.project(&[0, 2]).unwrap();
    let kept = projected.fields();
    assert_eq!(kept.len(), 2);
    assert_eq!(kept[0].name, "id");
    assert_eq!(kept[1].name, "age");
}

#[test]
fn test_row_type_project_empty_indices() {
    // Three-column source schema: id, name, age.
    let source = RowType::with_data_types_and_field_names(
        vec![DataTypes::int(), DataTypes::string(), DataTypes::bigint()],
        vec!["id", "name", "age"],
    );

    // Projecting onto no indices succeeds and produces a row type without fields.
    let projected = source.project(&[]).unwrap();
    let field_count = projected.fields().len();
    assert_eq!(field_count, 0);
}

#[test]
fn test_row_type_project_with_field_names_valid() {
    // Three-column source schema: id, name, age.
    let source = RowType::with_data_types_and_field_names(
        vec![DataTypes::int(), DataTypes::string(), DataTypes::bigint()],
        vec!["id", "name", "age"],
    );

    // Both requested names exist, so the projection keeps them in request order.
    let wanted = ["id".to_string(), "name".to_string()];
    let projected = source.project_with_field_names(&wanted).unwrap();

    let kept = projected.fields();
    assert_eq!(kept.len(), 2);
    assert_eq!(kept[0].name, "id");
    assert_eq!(kept[1].name, "name");
}

#[test]
fn test_row_type_project_index_out_of_bounds() {
    // Three-column source schema: id, name, age.
    let source = RowType::with_data_types_and_field_names(
        vec![DataTypes::int(), DataTypes::string(), DataTypes::bigint()],
        vec!["id", "name", "age"],
    );

    // Position 5 does not exist in a three-column row type, so projection must fail.
    let outcome = source.project(&[0, 5]);
    assert!(outcome.is_err());

    // The error text must identify the offending position.
    let message = outcome.unwrap_err().to_string();
    assert!(message.contains("invalid field position: 5"));
}

#[test]
fn test_row_type_project_with_field_names_nonexistent() {
    // Three-column source schema: id, name, age.
    let source = RowType::with_data_types_and_field_names(
        vec![DataTypes::int(), DataTypes::string(), DataTypes::bigint()],
        vec!["id", "name", "age"],
    );

    // Each request names a missing field: first alone, then after a valid name.
    // Both must return an error that names the missing field.
    let requests = [
        vec!["nonexistent".to_string()],
        vec!["id".to_string(), "nonexistent".to_string()],
    ];

    for names in &requests {
        let outcome = source.project_with_field_names(names);
        assert!(outcome.is_err());

        let message = outcome.unwrap_err().to_string();
        assert!(message.contains("Field 'nonexistent' does not exist in the row type"));
    }
}

#[test]
fn test_row_type_project_duplicate_indices() {
    // Three-column source schema: id, name, age.
    let source = RowType::with_data_types_and_field_names(
        vec![DataTypes::int(), DataTypes::string(), DataTypes::bigint()],
        vec!["id", "name", "age"],
    );

    // Repeated positions are allowed: [0, 0, 1] yields "id" twice, then "name".
    // This pins the current behavior rather than a stated requirement.
    let projected = source.project(&[0, 0, 1]).unwrap();
    let kept = projected.fields();
    assert_eq!(kept.len(), 3);
    assert_eq!(kept[0].name, "id");
    assert_eq!(kept[1].name, "id");
    assert_eq!(kept[2].name, "name");
}
Loading