# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

import logging

from marshmallow import fields, post_load

from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml._utils._experimental import experimental

module_logger = logging.getLogger(__name__)


@experimental
class AcceleratorMapSchema(metaclass=PatchedSchemaMeta):
    """Marshmallow schema for a single accelerator map entry.

    Loading data through this schema produces an ``AcceleratorMap`` entity
    (see :meth:`make`).
    """

    # Mandatory: name of the accelerator SKU.
    accelerator_type = fields.String(
        metadata={"description": "The type of accelerator (e.g. H100_80GB, H200_141GB, A100_80GB)."},
        required=True,
    )
    # Mandatory: how many accelerators one model instance occupies.
    number_of_accelerators_per_model_instance = fields.Integer(
        metadata={"description": "Number of accelerators per model instance."},
        required=True,
    )
    # Optional flag; when the key is absent from the input it loads as None.
    default = fields.Boolean(
        metadata={"description": "Whether this is the default accelerator map."},
        load_default=None,
    )

    @post_load
    def make(self, data, **kwargs):
        """Turn the validated field dictionary into an ``AcceleratorMap``.

        :param data: Deserialized field values keyed by field name.
        :type data: dict
        :return: The entity built from ``data``.
        :rtype: ~azure.ai.ml.entities.AcceleratorMap
        """
        # Imported at call time rather than module level (presumably to avoid an
        # import cycle between the schema and entity packages - TODO confirm).
        from azure.ai.ml.entities._deployment.accelerator_map import AcceleratorMap

        entity = AcceleratorMap(**data)
        return entity
+++ b/sdk/ml/azure-ai-ml/azure/ai/ml/_schema/assets/model.py @@ -8,9 +8,9 @@ from marshmallow import fields, post_load, pre_dump +from azure.ai.ml._schema.assets.default_deployment_template import DefaultDeploymentTemplateSchema from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema -from azure.ai.ml._schema.assets.default_deployment_template import DefaultDeploymentTemplateSchema from azure.ai.ml._schema.core.schema import PathAwareSchema from azure.ai.ml._schema.job import CreationContextSchema from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes, AzureMLResourceType @@ -46,6 +46,7 @@ class ModelSchema(PathAwareSchema): intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema, required=False), dump_only=True) system_metadata = fields.Dict() default_deployment_template = NestedField(DefaultDeploymentTemplateSchema, required=False) + allowed_deployment_templates = fields.List(NestedField(DefaultDeploymentTemplateSchema), required=False) @pre_dump def validate(self, data, **kwargs): diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py index 8abc7ddb31f4..7eb4ec766937 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py +++ b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py @@ -36,12 +36,10 @@ from ._assets._artifacts.index import Index from ._assets._artifacts.model import Model from ._assets.asset import Asset +from ._assets.default_deployment_template import DefaultDeploymentTemplate from ._assets.environment import BuildContext, Environment from ._assets.intellectual_property import IntellectualProperty -from ._assets.workspace_asset_reference import ( - WorkspaceAssetReference as WorkspaceModelReference, -) -from ._assets.default_deployment_template import DefaultDeploymentTemplate +from ._assets.workspace_asset_reference import 
WorkspaceAssetReference as WorkspaceModelReference from ._autogen_entities.models import ( AzureOpenAIDeployment, MarketplacePlan, @@ -55,40 +53,19 @@ from ._component.pipeline_component import PipelineComponent from ._component.spark_component import SparkComponent from ._compute._aml_compute_node_info import AmlComputeNodeInfo -from ._compute._custom_applications import ( - CustomApplications, - EndpointsSettings, - ImageSettings, - VolumeSettings, -) +from ._compute._custom_applications import CustomApplications, EndpointsSettings, ImageSettings, VolumeSettings from ._compute._image_metadata import ImageMetadata -from ._compute._schedule import ( - ComputePowerAction, - ComputeSchedules, - ComputeStartStopSchedule, - ScheduleState, -) +from ._compute._schedule import ComputePowerAction, ComputeSchedules, ComputeStartStopSchedule, ScheduleState from ._compute._setup_scripts import ScriptReference, SetupScripts from ._compute._usage import Usage, UsageName from ._compute._vm_size import VmSize from ._compute.aml_compute import AmlCompute, AmlComputeSshSettings from ._compute.compute import Compute, NetworkSettings -from ._compute.compute_instance import ( - AssignedUserConfiguration, - ComputeInstance, - ComputeInstanceSshSettings, -) +from ._compute.compute_instance import AssignedUserConfiguration, ComputeInstance, ComputeInstanceSshSettings from ._compute.kubernetes_compute import KubernetesCompute -from ._compute.synapsespark_compute import ( - AutoPauseSettings, - AutoScaleSettings, - SynapseSparkCompute, -) +from ._compute.synapsespark_compute import AutoPauseSettings, AutoScaleSettings, SynapseSparkCompute from ._compute.unsupported_compute import UnsupportedCompute -from ._compute.virtual_machine_compute import ( - VirtualMachineCompute, - VirtualMachineSshSettings, -) +from ._compute.virtual_machine_compute import VirtualMachineCompute, VirtualMachineSshSettings from ._credentials import ( AadCredentialConfiguration, AccessKeyConfiguration, @@ -108,13 
+85,10 @@ from ._data_import.data_import import DataImport from ._data_import.schedule import ImportDataSchedule from ._datastore.adls_gen1 import AzureDataLakeGen1Datastore -from ._datastore.azure_storage import ( - AzureBlobDatastore, - AzureDataLakeGen2Datastore, - AzureFileDatastore, -) +from ._datastore.azure_storage import AzureBlobDatastore, AzureDataLakeGen2Datastore, AzureFileDatastore from ._datastore.datastore import Datastore from ._datastore.one_lake import OneLakeArtifact, OneLakeDatastore +from ._deployment.accelerator_map import AcceleratorMap from ._deployment.batch_deployment import BatchDeployment from ._deployment.batch_job import BatchJob from ._deployment.code_configuration import CodeConfiguration @@ -122,11 +96,7 @@ from ._deployment.data_asset import DataAsset from ._deployment.data_collector import DataCollector from ._deployment.deployment_collection import DeploymentCollection -from ._deployment.deployment_settings import ( - BatchRetrySettings, - OnlineRequestSettings, - ProbeSettings, -) +from ._deployment.deployment_settings import BatchRetrySettings, OnlineRequestSettings, ProbeSettings from ._deployment.deployment_template import DeploymentTemplate from ._deployment.model_batch_deployment import ModelBatchDeployment from ._deployment.model_batch_deployment_settings import ModelBatchDeploymentSettings @@ -136,16 +106,10 @@ ManagedOnlineDeployment, OnlineDeployment, ) -from ._deployment.pipeline_component_batch_deployment import ( - PipelineComponentBatchDeployment, -) +from ._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment from ._deployment.request_logging import RequestLogging from ._deployment.resource_requirements_settings import ResourceRequirementsSettings -from ._deployment.scale_settings import ( - DefaultScaleSettings, - OnlineScaleSettings, - TargetUtilizationScaleSettings, -) +from ._deployment.scale_settings import DefaultScaleSettings, OnlineScaleSettings, 
TargetUtilizationScaleSettings from ._endpoint.batch_endpoint import BatchEndpoint from ._endpoint.endpoint import Endpoint from ._endpoint.online_endpoint import ( @@ -160,14 +124,10 @@ from ._feature_set.feature import Feature from ._feature_set.feature_set_backfill_metadata import FeatureSetBackfillMetadata from ._feature_set.feature_set_backfill_request import FeatureSetBackfillRequest -from ._feature_set.feature_set_materialization_metadata import ( - FeatureSetMaterializationMetadata, -) +from ._feature_set.feature_set_materialization_metadata import FeatureSetMaterializationMetadata from ._feature_set.feature_set_specification import FeatureSetSpecification from ._feature_set.feature_window import FeatureWindow -from ._feature_set.materialization_compute_resource import ( - MaterializationComputeResource, -) +from ._feature_set.materialization_compute_resource import MaterializationComputeResource from ._feature_set.materialization_settings import MaterializationSettings from ._feature_set.materialization_type import MaterializationType from ._feature_store.feature_store import FeatureStore @@ -183,15 +143,9 @@ from ._job.input_port import InputPort from ._job.job import Job from ._job.job_limits import CommandJobLimits -from ._job.job_resources import JobResources from ._job.job_resource_configuration import JobResourceConfiguration -from ._job.job_service import ( - JobService, - JupyterLabJobService, - SshJobService, - TensorBoardJobService, - VsCodeJobService, -) +from ._job.job_resources import JobResources +from ._job.job_service import JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService from ._job.parallel.parallel_task import ParallelTask from ._job.parallel.retry_settings import RetrySettings from ._job.parameterized_command import ParameterizedCommand @@ -207,12 +161,7 @@ from ._monitoring.alert_notification import AlertNotification from ._monitoring.compute import ServerlessSparkCompute from 
._monitoring.definition import MonitorDefinition -from ._monitoring.input_data import ( - FixedInputData, - MonitorInputData, - StaticInputData, - TrailingInputData, -) +from ._monitoring.input_data import FixedInputData, MonitorInputData, StaticInputData, TrailingInputData from ._monitoring.schedule import MonitorSchedule from ._monitoring.signals import ( BaselineDataRange, @@ -260,6 +209,7 @@ from ._schedule.trigger import CronTrigger, RecurrencePattern, RecurrenceTrigger from ._system_data import SystemData from ._validation import ValidationResult +from ._workspace._ai_workspaces.capability_host import CapabilityHost, CapabilityHostKind from ._workspace._ai_workspaces.hub import Hub from ._workspace._ai_workspaces.project import Project from ._workspace.compute_runtime import ComputeRuntime @@ -300,15 +250,7 @@ from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint from ._workspace.serverless_compute import ServerlessComputeSettings from ._workspace.workspace import Workspace -from ._workspace._ai_workspaces.capability_host import ( - CapabilityHost, - CapabilityHostKind, -) -from ._workspace.workspace_keys import ( - ContainerRegistryCredential, - NotebookAccessKeys, - WorkspaceKeys, -) +from ._workspace.workspace_keys import ContainerRegistryCredential, NotebookAccessKeys, WorkspaceKeys __all__ = [ "Resource", @@ -341,6 +283,7 @@ "Deployment", "BatchDeployment", "DeploymentTemplate", + "AcceleratorMap", "BatchJob", "CodeConfiguration", "Endpoint", @@ -573,6 +516,7 @@ "LocalSource", "IndexModelConfiguration", "DefaultDeploymentTemplate", + "AcceleratorMap", ] # Allow importing these types for backwards compatibility diff --git a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/_artifacts/model.py b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/_artifacts/model.py index f69a4586c946..b897575a0b08 100644 --- a/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/_artifacts/model.py +++ 
b/sdk/ml/azure-ai-ml/azure/ai/ml/entities/_assets/_artifacts/model.py @@ -3,19 +3,19 @@ # --------------------------------------------------------- from os import PathLike from pathlib import Path -from typing import Any, Dict, Optional, Union +from typing import Any, Dict, List, Optional, Union +from azure.ai.ml._restclient.v2021_10_01_dataplanepreview.models import ( + ModelVersionData, + ModelVersionDefaultDeploymentTemplate, + ModelVersionDetails, +) from azure.ai.ml._restclient.v2023_04_01_preview.models import ( FlavorData, ModelContainer, ModelVersion, ModelVersionProperties, ) -from azure.ai.ml._restclient.v2021_10_01_dataplanepreview.models import ( - ModelVersionDetails, - ModelVersionData, - ModelVersionDefaultDeploymentTemplate, -) from azure.ai.ml._schema import ModelSchema from azure.ai.ml._utils._arm_id_utils import AMLNamedArmId, AMLVersionedArmId from azure.ai.ml._utils._asset_utils import get_ignore_file, get_object_hash @@ -27,8 +27,8 @@ AssetTypes, ) from azure.ai.ml.entities._assets import Artifact -from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty from azure.ai.ml.entities._assets.default_deployment_template import DefaultDeploymentTemplate +from azure.ai.ml.entities._assets.intellectual_property import IntellectualProperty from azure.ai.ml.entities._system_data import SystemData from azure.ai.ml.entities._util import get_sha256_string, load_from_dict @@ -63,6 +63,8 @@ class Model(Artifact): # pylint: disable=too-many-instance-attributes :type stage: Optional[str] :param default_deployment_template: The default deployment template reference for the model. Defaults to None. :type default_deployment_template: Optional[DefaultDeploymentTemplate] + :param allowed_deployment_templates: List of allowed deployment template references for the model. + :type allowed_deployment_templates: Optional[list[DefaultDeploymentTemplate]] :param kwargs: A dictionary of additional configuration parameters. 
:type kwargs: Optional[dict] @@ -90,6 +92,7 @@ def __init__( properties: Optional[Dict] = None, stage: Optional[str] = None, default_deployment_template: Optional[DefaultDeploymentTemplate] = None, + allowed_deployment_templates: Optional[List[DefaultDeploymentTemplate]] = None, **kwargs: Any, ) -> None: self.job_name = kwargs.pop("job_name", None) @@ -115,6 +118,13 @@ def __init__( self.default_deployment_template = DefaultDeploymentTemplate(**default_deployment_template) else: self.default_deployment_template = default_deployment_template + # Handle allowed_deployment_templates - can be list of dicts or DefaultDeploymentTemplate objects + self.allowed_deployment_templates: Optional[List[DefaultDeploymentTemplate]] = None + if allowed_deployment_templates: + self.allowed_deployment_templates = [ + DefaultDeploymentTemplate(**item) if isinstance(item, dict) else item + for item in allowed_deployment_templates + ] if self._is_anonymous and self.path: _ignore_file = get_ignore_file(self.path) _upload_hash = get_object_hash(self.path, _ignore_file) @@ -168,6 +178,25 @@ def _from_rest_object(cls, model_rest_object: Union[ModelVersion, ModelVersionDa asset_id=getattr(rest_model_version.default_deployment_template, "asset_id", None) ) + # Handle allowed_deployment_templates from REST object + allowed_deployment_templates = None + if ( + hasattr(rest_model_version, "allowed_deployment_templates") + and rest_model_version.allowed_deployment_templates + ): + raw_list = rest_model_version.allowed_deployment_templates + if isinstance(raw_list, list): + allowed_deployment_templates = [] + for item in raw_list: + if isinstance(item, dict): + allowed_deployment_templates.append( + DefaultDeploymentTemplate(asset_id=item.get("asset_id") or item.get("assetId")) + ) + else: + allowed_deployment_templates.append( + DefaultDeploymentTemplate(asset_id=getattr(item, "asset_id", None)) + ) + model = Model( id=model_rest_object.id, name=arm_id.asset_name, @@ -189,6 +218,7 @@ def 
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

from typing import Optional

from azure.ai.ml._utils._experimental import experimental


@experimental
class AcceleratorMap:
    """Accelerator map for a deployment template, describing the accelerator type
    and how many accelerators are needed per model instance.

    Instances compare by value (see ``__eq__``). Because ``__eq__`` is defined
    without ``__hash__``, instances are unhashable and cannot be used as dict
    keys or set members.

    :param accelerator_type: The type of accelerator (e.g. "H100_80GB", "H200_141GB", "A100_80GB").
    :type accelerator_type: str
    :param number_of_accelerators_per_model_instance: Number of accelerators per model instance.
    :type number_of_accelerators_per_model_instance: int
    :param default: Whether this is the default accelerator map for the deployment template.
        ``None`` means "not specified" and is omitted from the REST payload.
    :type default: Optional[bool]
    """

    def __init__(
        self,
        *,
        accelerator_type: str,
        number_of_accelerators_per_model_instance: int,  # pylint: disable=name-too-long
        default: Optional[bool] = None,
        **kwargs,  # pylint: disable=unused-argument
    ) -> None:
        # Extra keyword arguments are accepted and ignored.
        self.accelerator_type = accelerator_type
        self.number_of_accelerators_per_model_instance = number_of_accelerators_per_model_instance
        self.default = default

    def _to_rest_dict(self) -> dict:
        """Convert to REST API dictionary.

        :return: Dictionary with camelCase keys for REST API. The ``default`` key
            is only present when ``default`` is not None.
        :rtype: dict
        """
        result = {
            "acceleratorType": self.accelerator_type,
            "numberOfAcceleratorsPerModelInstance": self.number_of_accelerators_per_model_instance,
        }
        if self.default is not None:
            result["default"] = self.default
        return result

    @classmethod
    def _from_rest_dict(cls, data: dict) -> "AcceleratorMap":
        """Create AcceleratorMap from REST API dictionary.

        camelCase keys take precedence. snake_case keys are accepted as a
        fallback so that dictionaries keyed by the entity's attribute names are
        not silently turned into an empty accelerator type and a count of zero.
        When neither spelling is present the historical defaults (``""`` and
        ``0``) are used.

        :param data: REST dictionary with camelCase keys (snake_case keys are tolerated).
        :type data: dict
        :return: AcceleratorMap instance.
        :rtype: AcceleratorMap
        """
        # Nested ``get`` (rather than ``a or b``) keeps falsy-but-present values
        # such as 0 from being treated as missing.
        accelerator_type = data.get("acceleratorType", data.get("accelerator_type", ""))
        accelerator_count = data.get(
            "numberOfAcceleratorsPerModelInstance",
            data.get("number_of_accelerators_per_model_instance", 0),
        )
        return cls(
            accelerator_type=accelerator_type,
            number_of_accelerators_per_model_instance=accelerator_count,
            # "default" is spelled identically in both casings.
            default=data.get("default"),
        )

    def __eq__(self, other: object) -> bool:
        """Compare all three fields; defer to the other operand for foreign types."""
        if not isinstance(other, AcceleratorMap):
            return NotImplemented
        return (
            self.accelerator_type == other.accelerator_type
            and self.number_of_accelerators_per_model_instance == other.number_of_accelerators_per_model_instance
            and self.default == other.default
        )

    # No explicit ``__ne__``: Python 3 derives it from ``__eq__`` (inverting the
    # result unless it is NotImplemented), which is what the hand-written version did.

    def __repr__(self) -> str:
        return (
            f"AcceleratorMap(accelerator_type={self.accelerator_type!r}, "
            f"number_of_accelerators_per_model_instance={self.number_of_accelerators_per_model_instance!r}, "
            f"default={self.default!r})"
        )
DeploymentTemplate(Resource, RestTranslatableMixin): # pylint: disable=to :type app_insights_enabled: bool :param stage: Stage of the deployment template. Can be "Active" or "Archived". :type stage: str + :param accelerator_maps: List of accelerator maps describing the accelerator types + and their configurations for this deployment template. + :type accelerator_maps: list[~azure.ai.ml.entities.AcceleratorMap] """ def __init__( # pylint: disable=too-many-locals @@ -77,6 +80,7 @@ def __init__( # pylint: disable=too-many-locals type: Optional[str] = None, deployment_template_type: Optional[str] = None, stage: Optional[str] = None, + accelerator_maps: Optional[List[AcceleratorMap]] = None, **kwargs, ): # Extract kwargs that should be passed to parent @@ -107,6 +111,7 @@ def __init__( # pylint: disable=too-many-locals self.type = type self.deployment_template_type = deployment_template_type self.stage = stage + self.accelerator_maps = accelerator_maps # Private flag to track if this template came from the service (and thus should exclude # immutable fields on update) @@ -306,6 +311,15 @@ def dump(self, dest: Union[str, PathLike, IO[AnyStr]] = None, **kwargs: Any) -> result["environment_variables"] = self.environment_variables # type: ignore[assignment] if self.app_insights_enabled is not None: result["app_insights_enabled"] = self.app_insights_enabled # type: ignore[assignment] + if self.accelerator_maps: + result["accelerator_maps"] = [ + { + "accelerator_type": am.accelerator_type, + "number_of_accelerators_per_model_instance": am.number_of_accelerators_per_model_instance, + **({"default": am.default} if am.default is not None else {}), + } + for am in self.accelerator_maps + ] return result @@ -371,9 +385,14 @@ def get_value(source, key, default=None): stage = get_value(properties, "stage") or get_value(obj, "stage") type_field = get_value(properties, "type") or get_value(obj, "type") + accelerator_maps_data = get_value(properties, "acceleratorMaps") or 
get_value(obj, "accelerator_maps") + accelerator_maps_list = None + if accelerator_maps_data and isinstance(accelerator_maps_data, list): + accelerator_maps_list = [AcceleratorMap._from_rest_dict(am) for am in accelerator_maps_data] + # Handle string representations from properties - they come as JSON strings - import json import ast + import json # Parse tags if it's a string if isinstance(tags, str): @@ -442,6 +461,7 @@ def get_value(source, key, default=None): model_mount_path=model_mount_path, # Include model mount path stage=stage, # Include stage for archive/restore functionality type=type_field, # Include type field from REST response + accelerator_maps=accelerator_maps_list, # Include accelerator maps ) # Mark this template as coming from the service so it excludes immutable fields on @@ -576,6 +596,9 @@ def _to_rest_object(self) -> dict: instance_types_array = [str(self.allowed_instance_types)] result["allowedInstanceTypes"] = instance_types_array # type: ignore[assignment] + if self.accelerator_maps: + result["acceleratorMaps"] = [am._to_rest_dict() for am in self.accelerator_maps] + return result def _to_dict(self) -> Dict: @@ -653,6 +676,8 @@ def _to_dict(self) -> Dict: result["scoringPath"] = self.scoring_path if self.scoring_port is not None: result["scoringPort"] = self.scoring_port # type: ignore[assignment] + if self.accelerator_maps: + result["acceleratorMaps"] = [am._to_rest_dict() for am in self.accelerator_maps] return result diff --git a/sdk/ml/azure-ai-ml/tests/deployment_template/unittests/test_accelerator_map.py b/sdk/ml/azure-ai-ml/tests/deployment_template/unittests/test_accelerator_map.py new file mode 100644 index 000000000000..e15a6a1ad354 --- /dev/null +++ b/sdk/ml/azure-ai-ml/tests/deployment_template/unittests/test_accelerator_map.py @@ -0,0 +1,408 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# ---------------------------------------------------------

import os
import tempfile
from pathlib import Path

import pytest

from azure.ai.ml.entities._deployment.accelerator_map import AcceleratorMap
from azure.ai.ml.entities._deployment.deployment_template import DeploymentTemplate


class TestAcceleratorMap:
    """Unit tests for construction, REST conversion, comparison and repr of AcceleratorMap."""

    @staticmethod
    def _build(accelerator_type, count, **extra):
        """Create an AcceleratorMap; ``extra`` carries the optional ``default`` flag."""
        return AcceleratorMap(
            accelerator_type=accelerator_type,
            number_of_accelerators_per_model_instance=count,
            **extra,
        )

    def test_basic_init(self):
        """Only the required fields are given; ``default`` stays None."""
        accel = self._build("H100_80GB", 4)

        assert accel.accelerator_type == "H100_80GB"
        assert accel.number_of_accelerators_per_model_instance == 4
        assert accel.default is None

    def test_full_init(self):
        """Every field is supplied explicitly."""
        accel = self._build("H200_141GB", 2, default=True)

        assert accel.accelerator_type == "H200_141GB"
        assert accel.number_of_accelerators_per_model_instance == 2
        assert accel.default is True

    def test_default_false(self):
        """An explicit ``default=False`` is stored as False, not None."""
        accel = self._build("A100_80GB", 8, default=False)

        assert accel.default is False

    def test_to_rest_dict(self):
        """The REST dictionary uses camelCase keys and carries ``default`` when set."""
        payload = self._build("H100_80GB", 4, default=True)._to_rest_dict()

        assert payload["acceleratorType"] == "H100_80GB"
        assert payload["numberOfAcceleratorsPerModelInstance"] == 4
        assert payload["default"] is True

    def test_to_rest_dict_without_default(self):
        """An unset ``default`` is left out of the REST dictionary."""
        payload = self._build("H100_80GB", 4)._to_rest_dict()

        assert payload["acceleratorType"] == "H100_80GB"
        assert payload["numberOfAcceleratorsPerModelInstance"] == 4
        assert "default" not in payload

    def test_from_rest_dict(self):
        """A complete camelCase payload populates every attribute."""
        payload = {
            "acceleratorType": "H200_141GB",
            "numberOfAcceleratorsPerModelInstance": 2,
            "default": True,
        }

        accel = AcceleratorMap._from_rest_dict(payload)

        assert accel.accelerator_type == "H200_141GB"
        assert accel.number_of_accelerators_per_model_instance == 2
        assert accel.default is True

    def test_from_rest_dict_without_default(self):
        """A payload lacking ``default`` yields ``default is None``."""
        payload = {
            "acceleratorType": "A100_80GB",
            "numberOfAcceleratorsPerModelInstance": 8,
        }

        accel = AcceleratorMap._from_rest_dict(payload)

        assert accel.accelerator_type == "A100_80GB"
        assert accel.number_of_accelerators_per_model_instance == 8
        assert accel.default is None

    def test_round_trip(self):
        """Entity -> REST dict -> entity gives back an equal object."""
        original = self._build("H100_80GB", 4, default=True)

        restored = AcceleratorMap._from_rest_dict(original._to_rest_dict())

        assert original == restored

    def test_equality(self):
        """Objects with identical fields are equal; differing fields are not."""
        first = self._build("H100_80GB", 4, default=True)
        twin = self._build("H100_80GB", 4, default=True)
        different = self._build("H200_141GB", 2)

        assert first == twin
        assert first != different

    def test_inequality_with_other_types(self):
        """Comparison against unrelated types reports inequality."""
        accel = self._build("H100_80GB", 4)

        assert accel != "not an accelerator map"
        assert accel != 42
        # ``!=`` (not ``is not``) is deliberate: it exercises the comparison protocol.
        assert accel != None  # noqa: E711

    def test_repr(self):
        """The repr mentions each field value."""
        text = repr(self._build("H100_80GB", 4, default=True))

        assert "H100_80GB" in text
        assert "4" in text
        assert "True" in text
rest_obj + assert len(rest_obj["acceleratorMaps"]) == 2 + assert rest_obj["acceleratorMaps"][0]["acceleratorType"] == "H100_80GB" + assert rest_obj["acceleratorMaps"][0]["numberOfAcceleratorsPerModelInstance"] == 4 + assert rest_obj["acceleratorMaps"][0]["default"] is True + assert rest_obj["acceleratorMaps"][1]["acceleratorType"] == "H200_141GB" + assert rest_obj["acceleratorMaps"][1]["numberOfAcceleratorsPerModelInstance"] == 2 + assert "default" not in rest_obj["acceleratorMaps"][1] + + def test_to_rest_object_without_accelerator_maps(self): + """Test _to_rest_object omits acceleratorMaps when None.""" + template = DeploymentTemplate(name="dt1", version="1") + + rest_obj = template._to_rest_object() + + assert "acceleratorMaps" not in rest_obj + + def test_from_rest_object_with_accelerator_maps(self): + """Test _from_rest_object deserializes acceleratorMaps.""" + rest_obj = { + "properties": { + "name": "dt1", + "version": "1", + "deploymentTemplateType": "ModelDeployment", + "environmentId": "azureml://registries/reg1/environments/env1/versions/1", + "allowedInstanceTypes": ["Standard_ND96isr_H100_v5"], + "defaultInstanceType": "Standard_ND96isr_H100_v5", + "instanceCount": 1, + "acceleratorMaps": [ + { + "acceleratorType": "H100_80GB", + "numberOfAcceleratorsPerModelInstance": 4, + "default": True, + }, + { + "acceleratorType": "H200_141GB", + "numberOfAcceleratorsPerModelInstance": 2, + }, + ], + }, + } + + template = DeploymentTemplate._from_rest_object(rest_obj) + + assert template.accelerator_maps is not None + assert len(template.accelerator_maps) == 2 + assert template.accelerator_maps[0].accelerator_type == "H100_80GB" + assert template.accelerator_maps[0].number_of_accelerators_per_model_instance == 4 + assert template.accelerator_maps[0].default is True + assert template.accelerator_maps[1].accelerator_type == "H200_141GB" + assert template.accelerator_maps[1].number_of_accelerators_per_model_instance == 2 + assert template.accelerator_maps[1].default 
is None + + def test_from_rest_object_without_accelerator_maps(self): + """Test _from_rest_object handles missing acceleratorMaps.""" + rest_obj = { + "properties": { + "name": "dt1", + "version": "1", + "deploymentTemplateType": "ModelDeployment", + "environmentId": "env1", + "allowedInstanceTypes": ["Standard_DS3_v2"], + "defaultInstanceType": "Standard_DS3_v2", + "instanceCount": 1, + }, + } + + template = DeploymentTemplate._from_rest_object(rest_obj) + + assert template.accelerator_maps is None + + def test_dump_with_accelerator_maps(self): + """Test dump() includes accelerator_maps.""" + maps = [ + AcceleratorMap(accelerator_type="H100_80GB", number_of_accelerators_per_model_instance=4, default=True), + AcceleratorMap(accelerator_type="H200_141GB", number_of_accelerators_per_model_instance=2), + ] + + template = DeploymentTemplate( + name="dt1", + version="1", + accelerator_maps=maps, + ) + + dumped = template.dump() + + assert "accelerator_maps" in dumped + assert len(dumped["accelerator_maps"]) == 2 + assert dumped["accelerator_maps"][0]["accelerator_type"] == "H100_80GB" + assert dumped["accelerator_maps"][0]["number_of_accelerators_per_model_instance"] == 4 + assert dumped["accelerator_maps"][0]["default"] is True + assert dumped["accelerator_maps"][1]["accelerator_type"] == "H200_141GB" + assert dumped["accelerator_maps"][1]["number_of_accelerators_per_model_instance"] == 2 + assert "default" not in dumped["accelerator_maps"][1] + + def test_to_dict_with_accelerator_maps(self): + """Test _to_dict() includes acceleratorMaps.""" + maps = [ + AcceleratorMap(accelerator_type="H100_80GB", number_of_accelerators_per_model_instance=4, default=True), + ] + + template = DeploymentTemplate(name="dt1", version="1", accelerator_maps=maps) + + result = template._to_dict() + + assert "acceleratorMaps" in result + assert len(result["acceleratorMaps"]) == 1 + assert result["acceleratorMaps"][0]["acceleratorType"] == "H100_80GB" + + def test_round_trip_rest_object(self): 
+ """Test full round-trip: entity -> REST -> entity.""" + original_maps = [ + AcceleratorMap(accelerator_type="H100_80GB", number_of_accelerators_per_model_instance=4, default=True), + AcceleratorMap(accelerator_type="H200_141GB", number_of_accelerators_per_model_instance=2), + ] + + original = DeploymentTemplate( + name="dt1", + version="1", + description="Test template with accelerator maps", + accelerator_maps=original_maps, + ) + + rest_obj = {"properties": original._to_rest_object()} + restored = DeploymentTemplate._from_rest_object(rest_obj) + + assert restored.name == original.name + assert restored.version == original.version + assert restored.accelerator_maps is not None + assert len(restored.accelerator_maps) == 2 + assert restored.accelerator_maps[0] == original_maps[0] + assert restored.accelerator_maps[1] == original_maps[1] + + def test_str_with_accelerator_maps(self): + """Test __str__ includes accelerator maps.""" + maps = [ + AcceleratorMap(accelerator_type="H100_80GB", number_of_accelerators_per_model_instance=4), + ] + + template = DeploymentTemplate(name="dt1", version="1", accelerator_maps=maps) + + str_repr = str(template) + assert "acceleratorMaps" in str_repr + assert "H100_80GB" in str_repr + + +class TestAcceleratorMapSchema: + """Tests for the AcceleratorMap marshmallow schema.""" + + @pytest.fixture + def dt_schema(self): + """Create a DeploymentTemplateSchema instance.""" + from azure.ai.ml._schema._deployment.template.deployment_template import DeploymentTemplateSchema + from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY + + temp_dir = tempfile.mkdtemp() + context = {BASE_PATH_CONTEXT_KEY: Path(temp_dir)} + schema = DeploymentTemplateSchema(context=context) + return schema + + def test_load_with_accelerator_maps(self, dt_schema): + """Test loading DeploymentTemplate data with accelerator_maps via schema.""" + data = { + "name": "dt-with-maps", + "version": "1", + "accelerator_maps": [ + { + "accelerator_type": "H100_80GB", 
+ "number_of_accelerators_per_model_instance": 4, + "default": True, + }, + { + "accelerator_type": "H200_141GB", + "number_of_accelerators_per_model_instance": 2, + }, + ], + } + + result = dt_schema.load(data) + + assert isinstance(result, DeploymentTemplate) + assert result.accelerator_maps is not None + assert len(result.accelerator_maps) == 2 + assert isinstance(result.accelerator_maps[0], AcceleratorMap) + assert result.accelerator_maps[0].accelerator_type == "H100_80GB" + assert result.accelerator_maps[0].number_of_accelerators_per_model_instance == 4 + assert result.accelerator_maps[0].default is True + assert result.accelerator_maps[1].accelerator_type == "H200_141GB" + assert result.accelerator_maps[1].number_of_accelerators_per_model_instance == 2 + assert result.accelerator_maps[1].default is None + + def test_load_without_accelerator_maps(self, dt_schema): + """Test loading DeploymentTemplate data without accelerator_maps.""" + data = { + "name": "dt-no-maps", + "version": "1", + } + + result = dt_schema.load(data) + + assert isinstance(result, DeploymentTemplate) + assert result.accelerator_maps is None + + def test_load_with_empty_accelerator_maps(self, dt_schema): + """Test loading DeploymentTemplate data with empty accelerator_maps list.""" + data = { + "name": "dt-empty-maps", + "version": "1", + "accelerator_maps": [], + } + + result = dt_schema.load(data) + + assert isinstance(result, DeploymentTemplate) + assert result.accelerator_maps == [] diff --git a/sdk/ml/azure-ai-ml/tests/model/unittests/test_model_allowed_deployment_templates.py b/sdk/ml/azure-ai-ml/tests/model/unittests/test_model_allowed_deployment_templates.py new file mode 100644 index 000000000000..59ad185ace92 --- /dev/null +++ b/sdk/ml/azure-ai-ml/tests/model/unittests/test_model_allowed_deployment_templates.py @@ -0,0 +1,321 @@ +# --------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. 
# ---------------------------------------------------------

"""Unit tests for Model with allowed_deployment_templates functionality."""

from pathlib import Path
from unittest.mock import Mock

import pytest

from azure.ai.ml._restclient.v2021_10_01_dataplanepreview.models import (
    ModelVersionData,
    ModelVersionDefaultDeploymentTemplate,
    ModelVersionDetails,
)
from azure.ai.ml._restclient.v2023_04_01_preview.models import ModelVersion, ModelVersionProperties
from azure.ai.ml.entities import Model
from azure.ai.ml.entities._assets.default_deployment_template import DefaultDeploymentTemplate

# Deployment-template asset IDs shared by the tests below.
_DT1_VERSIONED = "azureml://registries/reg1/deploymenttemplates/dt1/versions/1"
_DT2_VERSIONED = "azureml://registries/reg1/deploymenttemplates/dt2/versions/1"
_DT1_LATEST = "azureml://registries/reg1/deploymenttemplates/dt1/labels/latest"
_DT2_LATEST = "azureml://registries/reg1/deploymenttemplates/dt2/labels/latest"

# ID assigned to every mocked REST model object.
_MODEL_ARM_ID = (
    "/subscriptions/sub/resourceGroups/rg/providers/Microsoft.MachineLearningServices"
    "/workspaces/ws/models/test-model/versions/1"
)


def _make_model(**extra_kwargs) -> Model:
    """Build the ``test-model`` entity used by every test, forwarding any extra Model kwargs."""
    return Model(name="test-model", version="1", path="./model.pkl", **extra_kwargs)


def _mock_rest_properties(spec, **extra_attrs) -> Mock:
    """Return a ``Mock(spec=spec)`` carrying the property fields common to all _from_rest_object tests.

    :param spec: REST properties class the mock should imitate.
    :param extra_attrs: Additional attributes to set on the mock. Attributes that are not
        passed are deliberately left unset so the mock keeps whatever ``spec`` dictates.
    :return: The configured mock.
    """
    props = Mock(spec=spec)
    props.description = "Test model"
    props.tags = {}
    props.properties = {}
    props.flavors = {}
    props.model_uri = "azureml://test/model"
    props.model_type = "custom_model"
    props.stage = None
    props.job_name = None
    props.intellectual_property = None
    for attr_name, attr_value in extra_attrs.items():
        setattr(props, attr_name, attr_value)
    return props


def _mock_rest_object(spec, rest_properties: Mock) -> Mock:
    """Return a ``Mock(spec=spec)`` REST object wrapping ``rest_properties`` plus mocked system data.

    :param spec: REST envelope class the mock should imitate.
    :param rest_properties: Mocked properties to attach as ``.properties``.
    :return: The configured mock.
    """
    rest_object = Mock(spec=spec)
    rest_object.id = _MODEL_ARM_ID
    rest_object.properties = rest_properties
    rest_object.system_data = Mock(
        created_by="test",
        created_at=None,
        last_modified_by=None,
        last_modified_at=None,
    )
    return rest_object


@pytest.mark.unittest
@pytest.mark.production_experiences_test
class TestModelAllowedDeploymentTemplates:
    """Test cases for Model entity with allowed_deployment_templates.

    Covers construction, ``_to_rest_object``, ``_from_rest_object`` and YAML loading.
    """

    def test_model_init_with_allowed_deployment_templates(self) -> None:
        """Test creating a Model with allowed_deployment_templates."""
        templates = [
            DefaultDeploymentTemplate(asset_id=_DT1_VERSIONED),
            DefaultDeploymentTemplate(asset_id=_DT2_VERSIONED),
        ]
        model = _make_model(allowed_deployment_templates=templates)

        assert model.allowed_deployment_templates is not None
        assert len(model.allowed_deployment_templates) == 2
        assert model.allowed_deployment_templates[0].asset_id == templates[0].asset_id
        assert model.allowed_deployment_templates[1].asset_id == templates[1].asset_id

    def test_model_init_with_allowed_deployment_templates_as_dicts(self) -> None:
        """Test creating a Model with allowed_deployment_templates as list of dicts."""
        templates = [
            {"asset_id": _DT1_VERSIONED},
            {"asset_id": _DT2_VERSIONED},
        ]
        model = _make_model(allowed_deployment_templates=templates)  # type: ignore[arg-type]

        assert model.allowed_deployment_templates is not None
        assert len(model.allowed_deployment_templates) == 2
        assert all(isinstance(t, DefaultDeploymentTemplate) for t in model.allowed_deployment_templates)

    def test_model_init_without_allowed_deployment_templates(self) -> None:
        """Test creating a Model without allowed_deployment_templates."""
        model = _make_model()

        assert model.allowed_deployment_templates is None

    def test_model_init_with_both_default_and_allowed(self) -> None:
        """Test Model with both default and allowed deployment templates."""
        default = DefaultDeploymentTemplate(asset_id=_DT1_LATEST)
        allowed = [
            DefaultDeploymentTemplate(asset_id=_DT1_LATEST),
            DefaultDeploymentTemplate(asset_id=_DT2_LATEST),
        ]
        model = _make_model(default_deployment_template=default, allowed_deployment_templates=allowed)

        assert model.default_deployment_template is not None
        assert model.default_deployment_template.asset_id == default.asset_id
        assert model.allowed_deployment_templates is not None
        assert len(model.allowed_deployment_templates) == 2

    def test_model_to_rest_object_with_allowed_only(self) -> None:
        """Test _to_rest_object with only allowed_deployment_templates (no default)."""
        allowed = [DefaultDeploymentTemplate(asset_id=_DT1_VERSIONED)]
        model = _make_model(allowed_deployment_templates=allowed)

        rest_object = model._to_rest_object()

        # Should use ModelVersionData path
        assert isinstance(rest_object, ModelVersionData)
        assert isinstance(rest_object.properties, ModelVersionDetails)
        assert rest_object.properties.allowed_deployment_templates is not None
        assert len(rest_object.properties.allowed_deployment_templates) == 1
        assert rest_object.properties.allowed_deployment_templates[0].asset_id == allowed[0].asset_id

    def test_model_to_rest_object_with_both(self) -> None:
        """Test _to_rest_object with both default and allowed templates."""
        default = DefaultDeploymentTemplate(asset_id=_DT1_VERSIONED)
        allowed = [
            DefaultDeploymentTemplate(asset_id=_DT1_VERSIONED),
            DefaultDeploymentTemplate(asset_id=_DT2_VERSIONED),
        ]
        model = _make_model(default_deployment_template=default, allowed_deployment_templates=allowed)

        rest_object = model._to_rest_object()

        assert isinstance(rest_object, ModelVersionData)
        assert rest_object.properties.default_deployment_template is not None
        assert rest_object.properties.default_deployment_template.asset_id == default.asset_id
        assert rest_object.properties.allowed_deployment_templates is not None
        assert len(rest_object.properties.allowed_deployment_templates) == 2

    def test_model_to_rest_object_without_templates(self) -> None:
        """Test _to_rest_object without any deployment templates uses ModelVersion path."""
        model = _make_model()

        rest_object = model._to_rest_object()

        # Should use standard ModelVersion path
        assert isinstance(rest_object, ModelVersion)
        assert isinstance(rest_object.properties, ModelVersionProperties)

    def test_model_from_rest_object_with_allowed_deployment_templates_dict(self) -> None:
        """Test _from_rest_object with allowed_deployment_templates as list of dicts."""
        rest_properties = _mock_rest_properties(
            ModelVersionDetails,
            system_metadata=None,
            default_deployment_template=None,
            allowed_deployment_templates=[
                {"asset_id": _DT1_VERSIONED},
                {"asset_id": _DT2_VERSIONED},
            ],
        )
        rest_object = _mock_rest_object(ModelVersionData, rest_properties)

        model = Model._from_rest_object(rest_object)

        assert model.allowed_deployment_templates is not None
        assert len(model.allowed_deployment_templates) == 2
        assert all(isinstance(t, DefaultDeploymentTemplate) for t in model.allowed_deployment_templates)
        assert model.allowed_deployment_templates[0].asset_id == _DT1_VERSIONED
        assert model.allowed_deployment_templates[1].asset_id == _DT2_VERSIONED

    def test_model_from_rest_object_with_allowed_deployment_templates_objects(self) -> None:
        """Test _from_rest_object with allowed_deployment_templates as list of objects."""
        template_obj1 = Mock(asset_id=_DT1_VERSIONED)
        template_obj2 = Mock(asset_id=_DT2_VERSIONED)
        rest_properties = _mock_rest_properties(
            ModelVersionDetails,
            system_metadata=None,
            default_deployment_template=None,
            allowed_deployment_templates=[template_obj1, template_obj2],
        )
        rest_object = _mock_rest_object(ModelVersionData, rest_properties)

        model = Model._from_rest_object(rest_object)

        assert model.allowed_deployment_templates is not None
        assert len(model.allowed_deployment_templates) == 2
        assert model.allowed_deployment_templates[0].asset_id == template_obj1.asset_id
        assert model.allowed_deployment_templates[1].asset_id == template_obj2.asset_id

    def test_model_from_rest_object_without_allowed_deployment_templates(self) -> None:
        """Test _from_rest_object without allowed_deployment_templates."""
        # No template-related attributes are set here: the mock exposes only what the
        # ModelVersionProperties spec provides.
        rest_properties = _mock_rest_properties(ModelVersionProperties)
        rest_object = _mock_rest_object(ModelVersion, rest_properties)

        model = Model._from_rest_object(rest_object)

        assert model.allowed_deployment_templates is None

    def test_model_yaml_with_allowed_deployment_templates(self, tmp_path: Path) -> None:
        """Test loading a Model from YAML with allowed_deployment_templates."""
        from azure.ai.ml import load_model

        model_file = tmp_path / "model.pkl"
        model_file.write_text("dummy model")

        yaml_content = f"""
name: test-model
version: "1"
path: {model_file}
description: Test model with allowed deployment templates
default_deployment_template:
  asset_id: azureml://registries/reg1/deploymenttemplates/dt1/labels/latest
allowed_deployment_templates:
  - asset_id: azureml://registries/reg1/deploymenttemplates/dt1/labels/latest
  - asset_id: azureml://registries/reg1/deploymenttemplates/dt2/labels/latest
"""
        yaml_file = tmp_path / "model.yml"
        yaml_file.write_text(yaml_content)

        model = load_model(source=yaml_file)

        assert model.name == "test-model"
        assert model.default_deployment_template is not None
        assert model.default_deployment_template.asset_id is not None
        assert "dt1" in model.default_deployment_template.asset_id
        assert model.allowed_deployment_templates is not None
        assert len(model.allowed_deployment_templates) == 2
        assert all(isinstance(t, DefaultDeploymentTemplate) for t in model.allowed_deployment_templates)
        assert "dt1" in model.allowed_deployment_templates[0].asset_id
        assert "dt2" in model.allowed_deployment_templates[1].asset_id

    def test_model_round_trip_rest_with_both_templates(self) -> None:
        """Test the entity -> REST leg of the round-trip with both template types.

        Only serialization is verified; the REST -> entity leg is not exercised by this test.
        """
        default = DefaultDeploymentTemplate(asset_id=_DT1_VERSIONED)
        allowed = [
            DefaultDeploymentTemplate(asset_id=_DT1_VERSIONED),
            DefaultDeploymentTemplate(asset_id=_DT2_VERSIONED),
        ]
        original = _make_model(
            description="Test model",
            default_deployment_template=default,
            allowed_deployment_templates=allowed,
        )

        rest_object = original._to_rest_object()

        # Verify REST object shape
        assert isinstance(rest_object, ModelVersionData)
        assert rest_object.properties.default_deployment_template.asset_id == default.asset_id
        assert len(rest_object.properties.allowed_deployment_templates) == 2

        # We can't do a full round-trip via _from_rest_object because it expects ARM IDs,
        # so instead verify that each allowed template was serialized with its asset ID intact.
        assert rest_object.properties.allowed_deployment_templates[0].asset_id == allowed[0].asset_id
        assert rest_object.properties.allowed_deployment_templates[1].asset_id == allowed[1].asset_id