diff --git a/src/codeflare_sdk/common/utils/test_validation.py b/src/codeflare_sdk/common/utils/test_validation.py
new file mode 100644
index 00000000..20416d00
--- /dev/null
+++ b/src/codeflare_sdk/common/utils/test_validation.py
@@ -0,0 +1,238 @@
+# Copyright 2022-2025 IBM, Red Hat
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from codeflare_sdk.common.utils.validation import (
+    extract_ray_version_from_image,
+    validate_ray_version_compatibility,
+)
+from codeflare_sdk.common.utils.constants import RAY_VERSION
+
+
+class TestRayVersionDetection:
+    """Test Ray version detection from container image names."""
+
+    def test_extract_ray_version_standard_format(self):
+        """Test extraction from standard Ray image formats."""
+        # Standard format
+        assert extract_ray_version_from_image("ray:2.47.1") == "2.47.1"
+        assert extract_ray_version_from_image("ray:2.46.0") == "2.46.0"
+        assert extract_ray_version_from_image("ray:1.13.0") == "1.13.0"
+
+    def test_extract_ray_version_with_registry(self):
+        """Test extraction from images with registry prefixes."""
+        assert extract_ray_version_from_image("quay.io/ray:2.47.1") == "2.47.1"
+        assert (
+            extract_ray_version_from_image("docker.io/rayproject/ray:2.47.1")
+            == "2.47.1"
+        )
+        assert (
+            extract_ray_version_from_image("gcr.io/my-project/ray:2.47.1") == "2.47.1"
+        )
+
+    def test_extract_ray_version_with_suffixes(self):
+        """Test extraction from images with version suffixes."""
+        assert (
+            extract_ray_version_from_image("quay.io/modh/ray:2.47.1-py311-cu121")
+            == "2.47.1"
+        )
+        assert extract_ray_version_from_image("ray:2.47.1-py311") == "2.47.1"
+        assert extract_ray_version_from_image("ray:2.47.1-gpu") == "2.47.1"
+        assert extract_ray_version_from_image("ray:2.47.1-rocm62") == "2.47.1"
+
+    def test_extract_ray_version_complex_registry_paths(self):
+        """Test extraction from complex registry paths."""
+        # Registry host with an explicit port
+        assert extract_ray_version_from_image("localhost:5000/ray:2.47.1") == "2.47.1"
+        assert (
+            extract_ray_version_from_image("registry.company.com/team/ray:2.47.1")
+            == "2.47.1"
+        )
+
+    def test_extract_ray_version_no_version_found(self):
+        """Test cases where no version can be extracted."""
+        # SHA-based tags
+        assert (
+            extract_ray_version_from_image(
+                "quay.io/modh/ray@sha256:6d076aeb38ab3c34a6a2ef0f58dc667089aa15826fa08a73273c629333e12f1e"
+            )
+            is None
+        )
+
+        # Non-semantic versions
+        assert extract_ray_version_from_image("ray:latest") is None
+        assert extract_ray_version_from_image("ray:nightly") is None
+        assert extract_ray_version_from_image("ray:2.47") is None  # No patch version
+
+        # Non-Ray images
+        assert extract_ray_version_from_image("python:3.11") is None
+        assert extract_ray_version_from_image("ubuntu:20.04") is None
+
+        # Empty, None or non-string input
+        assert extract_ray_version_from_image("") is None
+        assert extract_ray_version_from_image(None) is None
+        assert extract_ray_version_from_image(123) is None
+
+    def test_extract_ray_version_edge_cases(self):
+        """Test edge cases for version extraction."""
+        # A 'v' prefix is not part of the supported tag format
+        assert extract_ray_version_from_image("ray:v2.47.1") is None
+        assert extract_ray_version_from_image("ray:v2.47") is None
+
+        # Multiple version-like patterns - should match the first valid one
+        assert (
+            extract_ray_version_from_image("registry/ray:2.47.1-based-on-1.0.0")
+            == "2.47.1"
+        )
+
+
+class TestRayVersionValidation:
+    """Test Ray version compatibility validation."""
+
+    def test_validate_compatible_versions(self):
+        """Test validation with compatible Ray versions."""
+        # Exact match
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            f"ray:{RAY_VERSION}"
+        )
+        assert is_compatible is True
+        assert is_warning is False
+        assert "Ray versions match" in message
+
+        # With registry and suffixes
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            f"quay.io/modh/ray:{RAY_VERSION}-py311-cu121"
+        )
+        assert is_compatible is True
+        assert is_warning is False
+        assert "Ray versions match" in message
+
+    def test_validate_incompatible_versions(self):
+        """Test validation with incompatible Ray versions."""
+        # Different version
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:2.46.0"
+        )
+        assert is_compatible is False
+        assert is_warning is False
+        assert "Ray version mismatch detected" in message
+        assert "CodeFlare SDK uses Ray" in message
+        assert "runtime image uses Ray" in message
+
+        # Older version
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:1.13.0"
+        )
+        assert is_compatible is False
+        assert is_warning is False
+        assert "Ray version mismatch detected" in message
+
+    def test_validate_empty_image(self):
+        """Test validation with no custom image (should use default)."""
+        # Empty string
+        is_compatible, is_warning, message = validate_ray_version_compatibility("")
+        assert is_compatible is True
+        assert is_warning is False
+        assert "Using default Ray image compatible with SDK" in message
+
+        # None
+        is_compatible, is_warning, message = validate_ray_version_compatibility(None)
+        assert is_compatible is True
+        assert is_warning is False
+        assert "Using default Ray image compatible with SDK" in message
+
+    def test_validate_unknown_version(self):
+        """Test validation when version cannot be determined."""
+        # SHA-based image
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "quay.io/modh/ray@sha256:6d076aeb38ab3c34a6a2ef0f58dc667089aa15826fa08a73273c629333e12f1e"
+        )
+        assert is_compatible is True
+        assert is_warning is True
+        assert "Cannot determine Ray version" in message
+
+        # Custom image without version
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "my-custom-ray:latest"
+        )
+        assert is_compatible is True
+        assert is_warning is True
+        assert "Cannot determine Ray version" in message
+
+    def test_validate_custom_sdk_version(self):
+        """Test validation with custom SDK version."""
+        # Compatible with custom SDK version
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:2.46.0", "2.46.0"
+        )
+        assert is_compatible is True
+        assert is_warning is False
+        assert "Ray versions match" in message
+
+        # Incompatible with custom SDK version
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:2.47.1", "2.46.0"
+        )
+        assert is_compatible is False
+        assert is_warning is False
+        assert "CodeFlare SDK uses Ray 2.46.0" in message
+        assert "runtime image uses Ray 2.47.1" in message
+
+    def test_validate_message_content(self):
+        """Test that validation messages contain expected guidance."""
+        # Mismatch message should contain helpful guidance
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:2.46.0"
+        )
+        assert is_compatible is False
+        assert is_warning is False
+        assert "compatibility issues" in message.lower()
+        assert "unexpected behavior" in message.lower()
+        assert "please use a runtime image" in message.lower()
+        assert "update your sdk version" in message.lower()
+
+    def test_semantic_version_comparison(self):
+        """Test that versions are compared semantically, not as raw strings."""
+        # Equivalent versions with different spellings must be treated as equal
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:2.47.1", "2.47.1.0"
+        )
+        assert is_compatible is True
+        assert is_warning is False
+        assert "Ray versions match" in message
+
+        # Different versions are reported as a mismatch in either direction
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:2.10.0", "2.9.1"
+        )
+        assert is_compatible is False
+        assert is_warning is False
+        assert "CodeFlare SDK uses Ray 2.9.1" in message
+        assert "runtime image uses Ray 2.10.0" in message
+
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:2.9.1", "2.10.0"
+        )
+        assert is_compatible is False
+        assert is_warning is False
+        assert "CodeFlare SDK uses Ray 2.10.0" in message
+        assert "runtime image uses Ray 2.9.1" in message
+
+    def test_unparsable_sdk_version_falls_back_to_string_comparison(self):
+        """Test that an unparsable SDK version still yields a mismatch result."""
+        is_compatible, is_warning, message = validate_ray_version_compatibility(
+            "ray:2.47.1", "not-a-version"
+        )
+        assert is_compatible is False
+        assert is_warning is False
+        assert "Ray version mismatch detected" in message
diff --git a/src/codeflare_sdk/common/utils/validation.py b/src/codeflare_sdk/common/utils/validation.py
new file mode 100644
index 00000000..ec749f7c
--- /dev/null
+++ b/src/codeflare_sdk/common/utils/validation.py
@@ -0,0 +1,140 @@
+# Copyright 2022-2025 IBM, Red Hat
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#      http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+"""
+Validation utilities for the CodeFlare SDK.
+
+This module contains validation functions used across the SDK for ensuring
+configuration compatibility and correctness.
+"""
+
+import logging
+import re
+from typing import Optional, Tuple
+from packaging.version import Version, InvalidVersion
+from .constants import RAY_VERSION
+
+logger = logging.getLogger(__name__)
+
+# Compiled once at import time; patterns are tried in order and the first match wins.
+_RAY_VERSION_PATTERNS = (
+    # Repository named "...ray": ray:2.47.1, quay.io/modh/ray:2.47.1-py311-cu121
+    re.compile(r"ray:(\d+\.\d+\.\d+)"),
+    # Repository below a path segment ending in "ray": registry/ray/worker:2.47.1
+    re.compile(r"ray/[^:]*:(\d+\.\d+\.\d+)"),
+)
+
+
+def extract_ray_version_from_image(image_name: Optional[str]) -> Optional[str]:
+    """
+    Extract Ray version from a container image name.
+
+    Supports various image naming patterns:
+    - quay.io/modh/ray:2.47.1-py311-cu121
+    - ray:2.47.1
+    - some-registry/ray:2.47.1-py311
+    - quay.io/modh/ray@sha256:... (falls back to None)
+
+    Only a plain three-part numeric version (MAJOR.MINOR.PATCH) directly after
+    the tag separator is recognised; tags such as "latest", "v2.47.1" or "2.47"
+    yield None.
+
+    Args:
+        image_name: The container image name/tag. None, empty and non-string
+            values are accepted and treated as "no version".
+
+    Returns:
+        The extracted Ray version, or None if not found
+    """
+    if not isinstance(image_name, str) or not image_name:
+        return None
+
+    for pattern in _RAY_VERSION_PATTERNS:
+        match = pattern.search(image_name)
+        if match:
+            return match.group(1)
+
+    # If we can't extract version, return None to indicate unknown
+    return None
+
+
+def _ray_versions_differ(image_ray_version: str, sdk_ray_version: str) -> bool:
+    """
+    Tell whether two Ray version strings denote different versions.
+
+    Versions are compared semantically (so "2.47.1" equals "2.47.1.0"). If either
+    string is not a valid version, a warning is logged and the raw strings are
+    compared instead.
+    """
+    try:
+        return Version(image_ray_version) != Version(sdk_ray_version)
+    except InvalidVersion as e:
+        logger.warning(
+            "Failed to parse version for comparison (%s), "
+            "falling back to string comparison",
+            e,
+        )
+        return image_ray_version != sdk_ray_version
+
+
+def validate_ray_version_compatibility(
+    image_name: Optional[str], sdk_ray_version: str = RAY_VERSION
+) -> Tuple[bool, bool, str]:
+    """
+    Validate that the Ray version in the runtime image matches the SDK's Ray version.
+
+    Args:
+        image_name: The container image name/tag
+        sdk_ray_version: The Ray version used by the CodeFlare SDK
+
+    Returns:
+        tuple: (is_compatible, is_warning, message)
+            - is_compatible: True if versions match or cannot be determined, False if mismatch
+            - is_warning: True if this is a warning (non-fatal), False otherwise
+            - message: Descriptive message about the validation result
+    """
+    if not image_name:
+        # No custom image specified, will use default - this is compatible
+        logger.debug("Using default Ray image compatible with SDK")
+        return True, False, "Using default Ray image compatible with SDK"
+
+    image_ray_version = extract_ray_version_from_image(image_name)
+
+    if image_ray_version is None:
+        # Cannot determine version from image name, issue a warning but allow
+        return (
+            True,
+            True,
+            f"Cannot determine Ray version from image '{image_name}'. Please ensure it's compatible with Ray {sdk_ray_version}",
+        )
+
+    if _ray_versions_differ(image_ray_version, sdk_ray_version):
+        message = (
+            "Ray version mismatch detected!\n"
+            f"CodeFlare SDK uses Ray {sdk_ray_version}, but runtime image uses Ray {image_ray_version}.\n"
+            "This mismatch can cause compatibility issues and unexpected behavior.\n"
+            f"Please use a runtime image with Ray {sdk_ray_version} or update your SDK version."
+        )
+        return False, False, message
+
+    # Versions match
+    logger.debug(
+        "Ray version validation successful: SDK and runtime image both use Ray %s",
+        sdk_ray_version,
+    )
+    return (
+        True,
+        False,
+        f"Ray versions match: SDK and runtime image both use Ray {sdk_ray_version}",
+    )
diff --git a/src/codeflare_sdk/ray/cluster/test_config.py b/src/codeflare_sdk/ray/cluster/test_config.py
index 6f002df1..dab58ba2 100644
--- a/src/codeflare_sdk/ray/cluster/test_config.py
+++ b/src/codeflare_sdk/ray/cluster/test_config.py
@@ -1,4 +1,4 @@
-# Copyright 2024 IBM, Red Hat
+# Copyright 2022-2025 IBM, Red Hat
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
diff --git a/src/codeflare_sdk/ray/rayjobs/rayjob.py b/src/codeflare_sdk/ray/rayjobs/rayjob.py
index a1577d91..6230a0e1 100644
--- a/src/codeflare_sdk/ray/rayjobs/rayjob.py
+++ b/src/codeflare_sdk/ray/rayjobs/rayjob.py
@@ -1,4 +1,4 @@
-# Copyright 2025 IBM, Red Hat
+# Copyright 2022-2025 IBM, Red Hat
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -17,12 +17,14 @@
 """
 
 import logging
+import warnings
 from typing import Dict, Any, Optional, Tuple
 from python_client.kuberay_job_api import RayjobApi
 
 from codeflare_sdk.ray.rayjobs.config import ManagedClusterConfig
 
 from ...common.utils import get_current_namespace
+from ...common.utils.validation import validate_ray_version_compatibility
 
 from .status import (
     RayJobDeploymentStatus,
@@ -149,6 +151,9 @@ def submit(self) -> str:
         if not self.entrypoint:
             raise ValueError("entrypoint must be provided to submit a RayJob")
 
+        # Validate the Ray version of the cluster_config image before building the CR
+        self._validate_ray_version_compatibility()
+
         # Build the RayJob custom resource
         rayjob_cr = self._build_rayjob_cr()
 
@@ -213,6 +218,50 @@ def _build_rayjob_cr(self) -> Dict[str, Any]:
 
         return rayjob_cr
 
+    def _validate_ray_version_compatibility(self) -> None:
+        """
+        Validate Ray version compatibility for the cluster_config image.
+
+        Only runs when this job creates its own cluster (cluster_config is set).
+
+        Raises:
+            ValueError: If the image's Ray version does not match the SDK's.
+        """
+        if self._cluster_config is not None:
+            self._validate_cluster_config_image()
+
+    def _validate_cluster_config_image(self) -> None:
+        """
+        Validate that the Ray version in cluster_config image matches the SDK's Ray version.
+
+        Validation is skipped when the config has no image, the image is empty,
+        or the image is not a string. A UserWarning is issued when the Ray
+        version cannot be determined from the image name.
+
+        Raises:
+            ValueError: If the image's Ray version does not match the SDK's.
+        """
+        image = getattr(self._cluster_config, "image", None)
+        if not image:
+            logger.debug("Cluster config has no image set, skipping validation")
+            return
+
+        if not isinstance(image, str):
+            logger.warning(
+                "Cluster config image should be a string, got %s: %s",
+                type(image).__name__,
+                image,
+            )
+            return  # Skip validation for malformed image
+
+        is_compatible, is_warning, message = validate_ray_version_compatibility(image)
+        if not is_compatible:
+            raise ValueError(f"Cluster config image: {message}")
+        if is_warning:
+            # stacklevel=4 attributes the warning to the caller of submit():
+            # submit -> _validate_ray_version_compatibility -> this method -> warn
+            warnings.warn(f"Cluster config image: {message}", stacklevel=4)
+
     def status(
         self, print_to_console: bool = True
     ) -> Tuple[CodeflareRayJobStatus, bool]:
diff --git a/src/codeflare_sdk/ray/rayjobs/test_rayjob.py b/src/codeflare_sdk/ray/rayjobs/test_rayjob.py
index 1ecd4b48..6827ed03 100644
--- a/src/codeflare_sdk/ray/rayjobs/test_rayjob.py
+++ b/src/codeflare_sdk/ray/rayjobs/test_rayjob.py
@@ -1,4 +1,4 @@
-# Copyright 2025 IBM, Red Hat
+# Copyright 2022-2025 IBM, Red Hat
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -18,6 +18,7 @@
 
 from codeflare_sdk.ray.rayjobs.rayjob import RayJob
 from codeflare_sdk.ray.cluster.config import ClusterConfiguration
+from codeflare_sdk.ray.rayjobs.config import ManagedClusterConfig
 
 
 def test_rayjob_submit_success(mocker):
@@ -992,3 +993,147 @@ def test_build_ray_cluster_spec_with_gcs_ft(mocker):
     gcs_ft = spec["gcsFaultToleranceOptions"]
     assert gcs_ft["redisAddress"] == "redis://redis-service:6379"
     assert gcs_ft["externalStorageNamespace"] == "storage-ns"
+
+
+class TestRayVersionValidation:
+    """Test Ray version validation in RayJob."""
+
+    @pytest.fixture(autouse=True)
+    def mock_api_instance(self, mocker):
+        """Patch kube config loading and RayjobApi for every test in this class."""
+        mocker.patch("kubernetes.config.load_kube_config")
+        mock_api_class = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.RayjobApi")
+        mock_api_instance = MagicMock()
+        mock_api_class.return_value = mock_api_instance
+        return mock_api_instance
+
+    def test_submit_with_cluster_config_compatible_image_passes(
+        self, mock_api_instance
+    ):
+        """Test that submission passes with compatible cluster_config image."""
+        mock_api_instance.submit_job.return_value = True
+
+        cluster_config = ManagedClusterConfig(image=f"ray:{RAY_VERSION}")
+
+        rayjob = RayJob(
+            job_name="test-job",
+            cluster_config=cluster_config,
+            namespace="test-namespace",
+            entrypoint="python script.py",
+        )
+
+        # Should not raise any validation errors
+        result = rayjob.submit()
+        assert result == "test-job"
+
+    def test_submit_with_cluster_config_incompatible_image_fails(
+        self, mock_api_instance
+    ):
+        """Test that submission fails with incompatible cluster_config image."""
+        cluster_config = ManagedClusterConfig(image="ray:2.8.0")  # Different version
+
+        rayjob = RayJob(
+            job_name="test-job",
+            cluster_config=cluster_config,
+            namespace="test-namespace",
+            entrypoint="python script.py",
+        )
+
+        # Should raise ValueError for version mismatch
+        with pytest.raises(
+            ValueError, match="Cluster config image: Ray version mismatch detected"
+        ):
+            rayjob.submit()
+
+        # Validation must reject the job before anything is sent to the API
+        mock_api_instance.submit_job.assert_not_called()
+
+    def test_validate_ray_version_compatibility_method(self):
+        """Test the _validate_ray_version_compatibility method directly."""
+        rayjob = RayJob(
+            job_name="test-job",
+            cluster_name="test-cluster",
+            namespace="test-namespace",
+            entrypoint="python script.py",
+        )
+
+        # Test with no cluster_config (should not raise)
+        rayjob._validate_ray_version_compatibility()
+
+        # Test with compatible cluster_config version (should not raise)
+        rayjob._cluster_config = ManagedClusterConfig(image=f"ray:{RAY_VERSION}")
+        rayjob._validate_ray_version_compatibility()
+
+        # Test with incompatible cluster_config version
+        rayjob._cluster_config = ManagedClusterConfig(image="ray:2.8.0")
+        with pytest.raises(
+            ValueError, match="Cluster config image: Ray version mismatch detected"
+        ):
+            rayjob._validate_ray_version_compatibility()
+
+        # Test with unknown cluster_config version (should warn but not fail)
+        rayjob._cluster_config = ManagedClusterConfig(image="custom-image:latest")
+        with pytest.warns(
+            UserWarning, match="Cluster config image: Cannot determine Ray version"
+        ):
+            rayjob._validate_ray_version_compatibility()
+
+    def test_validate_cluster_config_image_method(self):
+        """Test the _validate_cluster_config_image method directly."""
+        rayjob = RayJob(
+            job_name="test-job",
+            cluster_config=ManagedClusterConfig(),
+            namespace="test-namespace",
+            entrypoint="python script.py",
+        )
+
+        # Test with no image (should not raise)
+        rayjob._validate_cluster_config_image()
+
+        # Test with compatible image (should not raise)
+        rayjob._cluster_config.image = f"ray:{RAY_VERSION}"
+        rayjob._validate_cluster_config_image()
+
+        # Test with incompatible image
+        rayjob._cluster_config.image = "ray:2.8.0"
+        with pytest.raises(
+            ValueError, match="Cluster config image: Ray version mismatch detected"
+        ):
+            rayjob._validate_cluster_config_image()
+
+        # Test with unknown image (should warn but not fail)
+        rayjob._cluster_config.image = "custom-image:latest"
+        with pytest.warns(
+            UserWarning, match="Cluster config image: Cannot determine Ray version"
+        ):
+            rayjob._validate_cluster_config_image()
+
+    def test_validate_cluster_config_image_edge_cases(self, mocker):
+        """Test edge cases in _validate_cluster_config_image method."""
+        rayjob = RayJob(
+            job_name="test-job",
+            cluster_config=ManagedClusterConfig(),
+            namespace="test-namespace",
+            entrypoint="python script.py",
+        )
+
+        # Test with None image (should not raise)
+        rayjob._cluster_config.image = None
+        rayjob._validate_cluster_config_image()
+
+        # Test with empty string image (should not raise)
+        rayjob._cluster_config.image = ""
+        rayjob._validate_cluster_config_image()
+
+        # Test with non-string image (should log a warning and skip validation)
+        mock_logger = mocker.patch("codeflare_sdk.ray.rayjobs.rayjob.logger")
+        rayjob._cluster_config.image = 123
+        rayjob._validate_cluster_config_image()
+        mock_logger.warning.assert_called_once()
+
+        # Test with cluster config that has no image attribute
+        class MockClusterConfig:
+            pass
+
+        rayjob._cluster_config = MockClusterConfig()
+        rayjob._validate_cluster_config_image()