Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
# ---------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# ---------------------------------------------------------

# pylint: disable=unused-argument

import logging

from marshmallow import fields, post_load

from azure.ai.ml._schema.core.schema import PatchedSchemaMeta
from azure.ai.ml._utils._experimental import experimental

module_logger = logging.getLogger(__name__)


@experimental
class AcceleratorMapSchema(metaclass=PatchedSchemaMeta):
    """Marshmallow schema for a single accelerator map entry.

    Fields:
        accelerator_type: Required string naming the accelerator.
        number_of_accelerators_per_model_instance: Required integer count.
        default: Optional boolean flag; loads as ``None`` when omitted.

    After a successful load, the field values are handed to :meth:`make`,
    which wraps them in an ``AcceleratorMap`` entity.
    """

    # Required: accelerator identifier string.
    accelerator_type = fields.Str(
        metadata={"description": "The type of accelerator (e.g. H100_80GB, H200_141GB, A100_80GB)."},
        required=True,
    )
    # Required: accelerator count for each model instance.
    number_of_accelerators_per_model_instance = fields.Int(
        metadata={"description": "Number of accelerators per model instance."},
        required=True,
    )
    # Optional: tri-state flag (True / False / None when not supplied).
    default = fields.Bool(
        metadata={"description": "Whether this is the default accelerator map."},
        load_default=None,
    )

    @post_load
    def make(self, data, **kwargs):
        """Convert loaded field values into an ``AcceleratorMap`` entity.

        :param data: Mapping of deserialized field names to values.
        :param kwargs: Extra keyword arguments supplied by marshmallow (unused).
        :return: An ``AcceleratorMap`` built from ``data``.
        """
        # Deferred import — presumably to avoid a circular import between the
        # schema and entity packages; confirm before hoisting to module level.
        from azure.ai.ml.entities._deployment.accelerator_map import AcceleratorMap

        accelerator_map = AcceleratorMap(**data)
        return accelerator_map
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,6 @@
from marshmallow import fields, post_load

from azure.ai.ml._schema.assets.environment import AnonymousEnvironmentSchema, EnvironmentSchema
from azure.ai.ml.constants._common import AzureMLResourceType
from azure.ai.ml._schema.core.fields import (
ArmVersionedStr,
NestedField,
Expand All @@ -21,7 +20,9 @@
VersionField,
)
from azure.ai.ml._utils._experimental import experimental
from azure.ai.ml.constants._common import AzureMLResourceType

from .accelerator_map_schema import AcceleratorMapSchema
from .probe_settings_schema import ProbeSettingsSchema
from .request_settings_schema import RequestSettingsSchema

Expand Down Expand Up @@ -56,6 +57,7 @@ class DeploymentTemplateSchema(PathAwareSchema):
)
scoring_port = fields.Int()
scoring_path = fields.Str()
accelerator_maps = fields.List(NestedField(AcceleratorMapSchema))

@post_load
def make(self, data, **kwargs): # pylint: disable=unused-argument
Expand Down
3 changes: 2 additions & 1 deletion sdk/ml/azure-ai-ml/azure/ai/ml/_schema/assets/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,9 @@

from marshmallow import fields, post_load, pre_dump

from azure.ai.ml._schema.assets.default_deployment_template import DefaultDeploymentTemplateSchema
from azure.ai.ml._schema.core.fields import ExperimentalField, NestedField
from azure.ai.ml._schema.core.intellectual_property import IntellectualPropertySchema
from azure.ai.ml._schema.assets.default_deployment_template import DefaultDeploymentTemplateSchema
from azure.ai.ml._schema.core.schema import PathAwareSchema
from azure.ai.ml._schema.job import CreationContextSchema
from azure.ai.ml.constants._common import BASE_PATH_CONTEXT_KEY, AssetTypes, AzureMLResourceType
Expand Down Expand Up @@ -46,6 +46,7 @@ class ModelSchema(PathAwareSchema):
intellectual_property = ExperimentalField(NestedField(IntellectualPropertySchema, required=False), dump_only=True)
system_metadata = fields.Dict()
default_deployment_template = NestedField(DefaultDeploymentTemplateSchema, required=False)
allowed_deployment_templates = fields.List(NestedField(DefaultDeploymentTemplateSchema), required=False)

@pre_dump
def validate(self, data, **kwargs):
Expand Down
98 changes: 21 additions & 77 deletions sdk/ml/azure-ai-ml/azure/ai/ml/entities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,12 +36,10 @@
from ._assets._artifacts.index import Index
from ._assets._artifacts.model import Model
from ._assets.asset import Asset
from ._assets.default_deployment_template import DefaultDeploymentTemplate
from ._assets.environment import BuildContext, Environment
from ._assets.intellectual_property import IntellectualProperty
from ._assets.workspace_asset_reference import (
WorkspaceAssetReference as WorkspaceModelReference,
)
from ._assets.default_deployment_template import DefaultDeploymentTemplate
from ._assets.workspace_asset_reference import WorkspaceAssetReference as WorkspaceModelReference
from ._autogen_entities.models import (
AzureOpenAIDeployment,
MarketplacePlan,
Expand All @@ -55,40 +53,19 @@
from ._component.pipeline_component import PipelineComponent
from ._component.spark_component import SparkComponent
from ._compute._aml_compute_node_info import AmlComputeNodeInfo
from ._compute._custom_applications import (
CustomApplications,
EndpointsSettings,
ImageSettings,
VolumeSettings,
)
from ._compute._custom_applications import CustomApplications, EndpointsSettings, ImageSettings, VolumeSettings
from ._compute._image_metadata import ImageMetadata
from ._compute._schedule import (
ComputePowerAction,
ComputeSchedules,
ComputeStartStopSchedule,
ScheduleState,
)
from ._compute._schedule import ComputePowerAction, ComputeSchedules, ComputeStartStopSchedule, ScheduleState
from ._compute._setup_scripts import ScriptReference, SetupScripts
from ._compute._usage import Usage, UsageName
from ._compute._vm_size import VmSize
from ._compute.aml_compute import AmlCompute, AmlComputeSshSettings
from ._compute.compute import Compute, NetworkSettings
from ._compute.compute_instance import (
AssignedUserConfiguration,
ComputeInstance,
ComputeInstanceSshSettings,
)
from ._compute.compute_instance import AssignedUserConfiguration, ComputeInstance, ComputeInstanceSshSettings
from ._compute.kubernetes_compute import KubernetesCompute
from ._compute.synapsespark_compute import (
AutoPauseSettings,
AutoScaleSettings,
SynapseSparkCompute,
)
from ._compute.synapsespark_compute import AutoPauseSettings, AutoScaleSettings, SynapseSparkCompute
from ._compute.unsupported_compute import UnsupportedCompute
from ._compute.virtual_machine_compute import (
VirtualMachineCompute,
VirtualMachineSshSettings,
)
from ._compute.virtual_machine_compute import VirtualMachineCompute, VirtualMachineSshSettings
from ._credentials import (
AadCredentialConfiguration,
AccessKeyConfiguration,
Expand All @@ -108,25 +85,18 @@
from ._data_import.data_import import DataImport
from ._data_import.schedule import ImportDataSchedule
from ._datastore.adls_gen1 import AzureDataLakeGen1Datastore
from ._datastore.azure_storage import (
AzureBlobDatastore,
AzureDataLakeGen2Datastore,
AzureFileDatastore,
)
from ._datastore.azure_storage import AzureBlobDatastore, AzureDataLakeGen2Datastore, AzureFileDatastore
from ._datastore.datastore import Datastore
from ._datastore.one_lake import OneLakeArtifact, OneLakeDatastore
from ._deployment.accelerator_map import AcceleratorMap
from ._deployment.batch_deployment import BatchDeployment
from ._deployment.batch_job import BatchJob
from ._deployment.code_configuration import CodeConfiguration
from ._deployment.container_resource_settings import ResourceSettings
from ._deployment.data_asset import DataAsset
from ._deployment.data_collector import DataCollector
from ._deployment.deployment_collection import DeploymentCollection
from ._deployment.deployment_settings import (
BatchRetrySettings,
OnlineRequestSettings,
ProbeSettings,
)
from ._deployment.deployment_settings import BatchRetrySettings, OnlineRequestSettings, ProbeSettings
from ._deployment.deployment_template import DeploymentTemplate
from ._deployment.model_batch_deployment import ModelBatchDeployment
from ._deployment.model_batch_deployment_settings import ModelBatchDeploymentSettings
Expand All @@ -136,16 +106,10 @@
ManagedOnlineDeployment,
OnlineDeployment,
)
from ._deployment.pipeline_component_batch_deployment import (
PipelineComponentBatchDeployment,
)
from ._deployment.pipeline_component_batch_deployment import PipelineComponentBatchDeployment
from ._deployment.request_logging import RequestLogging
from ._deployment.resource_requirements_settings import ResourceRequirementsSettings
from ._deployment.scale_settings import (
DefaultScaleSettings,
OnlineScaleSettings,
TargetUtilizationScaleSettings,
)
from ._deployment.scale_settings import DefaultScaleSettings, OnlineScaleSettings, TargetUtilizationScaleSettings
from ._endpoint.batch_endpoint import BatchEndpoint
from ._endpoint.endpoint import Endpoint
from ._endpoint.online_endpoint import (
Expand All @@ -160,14 +124,10 @@
from ._feature_set.feature import Feature
from ._feature_set.feature_set_backfill_metadata import FeatureSetBackfillMetadata
from ._feature_set.feature_set_backfill_request import FeatureSetBackfillRequest
from ._feature_set.feature_set_materialization_metadata import (
FeatureSetMaterializationMetadata,
)
from ._feature_set.feature_set_materialization_metadata import FeatureSetMaterializationMetadata
from ._feature_set.feature_set_specification import FeatureSetSpecification
from ._feature_set.feature_window import FeatureWindow
from ._feature_set.materialization_compute_resource import (
MaterializationComputeResource,
)
from ._feature_set.materialization_compute_resource import MaterializationComputeResource
from ._feature_set.materialization_settings import MaterializationSettings
from ._feature_set.materialization_type import MaterializationType
from ._feature_store.feature_store import FeatureStore
Expand All @@ -183,15 +143,9 @@
from ._job.input_port import InputPort
from ._job.job import Job
from ._job.job_limits import CommandJobLimits
from ._job.job_resources import JobResources
from ._job.job_resource_configuration import JobResourceConfiguration
from ._job.job_service import (
JobService,
JupyterLabJobService,
SshJobService,
TensorBoardJobService,
VsCodeJobService,
)
from ._job.job_resources import JobResources
from ._job.job_service import JobService, JupyterLabJobService, SshJobService, TensorBoardJobService, VsCodeJobService
from ._job.parallel.parallel_task import ParallelTask
from ._job.parallel.retry_settings import RetrySettings
from ._job.parameterized_command import ParameterizedCommand
Expand All @@ -207,12 +161,7 @@
from ._monitoring.alert_notification import AlertNotification
from ._monitoring.compute import ServerlessSparkCompute
from ._monitoring.definition import MonitorDefinition
from ._monitoring.input_data import (
FixedInputData,
MonitorInputData,
StaticInputData,
TrailingInputData,
)
from ._monitoring.input_data import FixedInputData, MonitorInputData, StaticInputData, TrailingInputData
from ._monitoring.schedule import MonitorSchedule
from ._monitoring.signals import (
BaselineDataRange,
Expand Down Expand Up @@ -260,6 +209,7 @@
from ._schedule.trigger import CronTrigger, RecurrencePattern, RecurrenceTrigger
from ._system_data import SystemData
from ._validation import ValidationResult
from ._workspace._ai_workspaces.capability_host import CapabilityHost, CapabilityHostKind
from ._workspace._ai_workspaces.hub import Hub
from ._workspace._ai_workspaces.project import Project
from ._workspace.compute_runtime import ComputeRuntime
Expand Down Expand Up @@ -300,15 +250,7 @@
from ._workspace.private_endpoint import EndpointConnection, PrivateEndpoint
from ._workspace.serverless_compute import ServerlessComputeSettings
from ._workspace.workspace import Workspace
from ._workspace._ai_workspaces.capability_host import (
CapabilityHost,
CapabilityHostKind,
)
from ._workspace.workspace_keys import (
ContainerRegistryCredential,
NotebookAccessKeys,
WorkspaceKeys,
)
from ._workspace.workspace_keys import ContainerRegistryCredential, NotebookAccessKeys, WorkspaceKeys

__all__ = [
"Resource",
Expand Down Expand Up @@ -341,6 +283,7 @@
"Deployment",
"BatchDeployment",
"DeploymentTemplate",
"AcceleratorMap",
"BatchJob",
"CodeConfiguration",
"Endpoint",
Expand Down Expand Up @@ -573,6 +516,7 @@
"LocalSource",
"IndexModelConfiguration",
"DefaultDeploymentTemplate",
"AcceleratorMap",
]

# Allow importing these types for backwards compatibility
Expand Down
Loading
Loading