Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions armis_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
from armis_sdk.core.armis_sdk import ArmisSdk
from armis_sdk.core.client_credentials import ClientCredentials
from armis_sdk.core.armis_sdk import ArmisSdk # noqa: F401
from armis_sdk.core.client_credentials import ClientCredentials # noqa: F401
31 changes: 14 additions & 17 deletions armis_sdk/clients/assets_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
import datetime
from typing import AsyncIterator
from typing import Optional
from typing import Type
from typing import Union
from collections.abc import AsyncIterator

import universalasync

Expand Down Expand Up @@ -31,10 +28,10 @@ class AssetsClient(BaseEntityClient): # pylint: disable=too-few-public-methods

async def list_by_asset_id(
self,
asset_class: Type[AssetT],
asset_ids: Union[list[int], list[str]],
asset_class: type[AssetT],
asset_ids: list[int] | list[str],
asset_id_source: AssetIdSource = "ASSET_ID",
fields: Optional[list[str]] = None,
fields: list[str] | None = None,
) -> AsyncIterator[AssetT]:
"""List assets by asset ID or other identifiers.

Expand Down Expand Up @@ -81,9 +78,9 @@ async def main():

async def list_by_last_seen(
self,
asset_class: Type[AssetT],
last_seen: Union[datetime.datetime, datetime.timedelta],
fields: Optional[list[str]] = None,
asset_class: type[AssetT],
last_seen: datetime.datetime | datetime.timedelta,
fields: list[str] | None = None,
) -> AsyncIterator[AssetT]:
"""List assets by last seen timestamp.

Expand Down Expand Up @@ -120,7 +117,7 @@ async def main():
asyncio.run(main())
```
"""
filter_: dict[str, Union[str, int]] = {"filter_criteria": "LAST_SEEN"}
filter_: dict[str, str | int] = {"filter_criteria": "LAST_SEEN"}

if isinstance(last_seen, datetime.datetime):
filter_["last_seen_ge"] = last_seen.isoformat()
Expand All @@ -133,7 +130,7 @@ async def main():
yield item

async def list_fields(
self, asset_class: Type[AssetT]
self, asset_class: type[AssetT]
) -> AsyncIterator[AssetFieldDescription]:
"""List all available fields for a given asset class.

Expand Down Expand Up @@ -244,7 +241,7 @@ async def main():
def _create_bulk_update_request(
cls,
asset: Asset,
asset_id: Union[str, int],
asset_id: str | int,
field: str,
):
request = {"asset_id": asset_id, "key": field}
Expand All @@ -266,7 +263,7 @@ def _get_asset_id(
asset: Asset,
index: int,
asset_id_source: AssetIdSource,
) -> Union[str, int]:
) -> str | int:
if isinstance(asset, Device):
return cls._get_device_asset_id(asset, index, asset_id_source)

Expand Down Expand Up @@ -324,8 +321,8 @@ def _is_integration_field(cls, field: str) -> bool:

async def _list_assets(
self,
asset_class: Type[AssetT],
fields: Optional[list[str]],
asset_class: type[AssetT],
fields: list[str] | None,
filter_: dict,
) -> AsyncIterator[AssetT]:
fields = fields or sorted(asset_class.all_fields())
Expand Down Expand Up @@ -353,7 +350,7 @@ def _validate_asset_class(cls, assets: list[AssetT]):
@classmethod
def _validate_fields(
cls,
asset_class: Type[AssetT],
asset_class: type[AssetT],
fields: list[str],
allow_model_members=True,
):
Expand Down
15 changes: 7 additions & 8 deletions armis_sdk/clients/collectors_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import contextlib
from collections.abc import AsyncIterator
from collections.abc import Generator
from typing import IO
from typing import AsyncIterator
from typing import Generator
from typing import Union

import httpx
import universalasync
Expand All @@ -25,7 +24,7 @@ class CollectorsClient(BaseEntityClient):

async def download_image(
self,
destination: Union[str, IO[bytes]],
destination: str | IO[bytes],
image_type: CollectorImageType = "OVA",
) -> AsyncIterator[DownloadProgress]:
"""Download a collector image to a specified destination path / file.
Expand All @@ -48,12 +47,12 @@ async def main():
collectors_client = CollectorsClient()

# Download to a path
async for progress in armis_sdk.collectors.download_image("/tmp/collector.ova"):
async for progress in collectors_client.download_image("/tmp/collector.ova"):
print(progress.percent)

# Download to a file
with open("/tmp/collector.ova", "wb") as file:
async for progress in armis_sdk.collectors.download_image(file):
async for progress in collectors_client.download_image(file):
print(progress.percent)

asyncio.run(main())
Expand All @@ -67,7 +66,7 @@ async def main():
etc.
"""
collector_image = await self.get_image(image_type=image_type)
async with httpx.AsyncClient() as client:
async with httpx.AsyncClient() as client: # noqa: SIM117
async with client.stream("GET", collector_image.url) as response:
response.raise_for_status()
total_size = int(response.headers.get("Content-Length", "0"))
Expand Down Expand Up @@ -114,7 +113,7 @@ async def main():
@classmethod
@contextlib.contextmanager
def open_file(
cls, destination: Union[str, IO[bytes]]
cls, destination: str | IO[bytes]
) -> Generator[IO[bytes], None, None]:
if isinstance(destination, str):
with open(destination, "wb") as file:
Expand Down
13 changes: 6 additions & 7 deletions armis_sdk/clients/data_export_client.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import asyncio
from collections.abc import AsyncIterator
from typing import Any
from typing import AsyncIterator
from typing import Type

import pandas
import universalasync
Expand All @@ -17,7 +16,7 @@
@universalasync.wrap
class DataExportClient(BaseEntityClient):

async def disable(self, entity: Type[BaseExportedEntity]):
async def disable(self, entity: type[BaseExportedEntity]):
"""Disable data export of the entity.

Args:
Expand All @@ -40,7 +39,7 @@ async def main():
"""
await self.toggle(entity, False)

async def enable(self, entity: Type[BaseExportedEntity]):
async def enable(self, entity: type[BaseExportedEntity]):
"""Enable data export of the entity.

Args:
Expand All @@ -63,7 +62,7 @@ async def main():
"""
await self.toggle(entity, True)

async def iterate(self, entity: Type[T], **kwargs: Any) -> AsyncIterator[T]:
async def iterate(self, entity: type[T], **kwargs: Any) -> AsyncIterator[T]:
# pylint: disable=line-too-long
"""Iterate over the exported data.

Expand Down Expand Up @@ -136,7 +135,7 @@ async def main():
for _, row in data_frame.iterrows():
yield entity.series_to_model(row)

async def get(self, entity: Type[BaseExportedEntity]) -> DataExport:
async def get(self, entity: type[BaseExportedEntity]) -> DataExport:
"""Get the `DataExport` of the entity

Args:
Expand Down Expand Up @@ -169,7 +168,7 @@ async def main():
data = response_utils.get_data_dict(response)
return DataExport.model_validate(data)

async def toggle(self, entity: Type[BaseExportedEntity], enabled: bool):
async def toggle(self, entity: type[BaseExportedEntity], enabled: bool):
"""Enable / disable export of an entity.

Args:
Expand Down
2 changes: 1 addition & 1 deletion armis_sdk/clients/device_custom_properties_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from typing import AsyncIterator
from collections.abc import AsyncIterator

import universalasync

Expand Down
5 changes: 2 additions & 3 deletions armis_sdk/clients/sites_client.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
from typing import AsyncIterator
from typing import List
from collections.abc import AsyncIterator

import universalasync

Expand Down Expand Up @@ -129,7 +128,7 @@ async def main():
data = response_utils.get_data_dict(response)
return Site.model_validate(data)

async def hierarchy(self) -> List[Site]:
async def hierarchy(self) -> list[Site]:
"""Create a hierarchy of the tenant's sites, taking into account the parent-child relationships.

Returns:
Expand Down
5 changes: 2 additions & 3 deletions armis_sdk/core/armis_auth.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import datetime
import typing
from typing import Optional

import httpx

Expand Down Expand Up @@ -28,8 +27,8 @@ class ArmisAuth(httpx.Auth):
def __init__(self, base_url: str, credentials: ClientCredentials):
self._base_url = base_url
self._credentials = credentials
self._access_token: Optional[str] = None
self._expires_at: Optional[datetime.datetime] = None
self._access_token: str | None = None
self._expires_at: datetime.datetime | None = None

def auth_flow(
self, request: httpx.Request
Expand Down
11 changes: 5 additions & 6 deletions armis_sdk/core/armis_client.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
import importlib.metadata
import os
import platform
from typing import AsyncIterator
from typing import Optional
from collections.abc import AsyncIterator
from typing import TypeVar

import httpx
Expand Down Expand Up @@ -48,7 +47,7 @@ class ArmisClient: # pylint: disable=too-few-public-methods
4. Proxy configuration via HTTPS_PROXY and HTTP_PROXY environment variables.
"""

def __init__(self, credentials: Optional[ClientCredentials] = None):
def __init__(self, credentials: ClientCredentials | None = None):
credentials = self._get_credentials(credentials)
self._auth = ArmisAuth(API_BASE_URL, credentials)
self._user_agent = " ".join(USER_AGENT_PARTS)
Expand All @@ -61,7 +60,7 @@ def __init__(self, credentials: Optional[ClientCredentials] = None):
except ValueError:
self._default_backoff = 0

def client(self, retries: Optional[int] = None, backoff: Optional[float] = None):
def client(self, retries: int | None = None, backoff: float | None = None):
retries = retries if retries is not None else self._default_retries
backoff = backoff if backoff is not None else self._default_backoff
retry = Retry(total=retries, backoff_factor=backoff)
Expand All @@ -82,7 +81,7 @@ def client(self, retries: Optional[int] = None, backoff: Optional[float] = None)
trust_env=True,
)

async def list(self, url: str, body: Optional[dict] = None) -> AsyncIterator[dict]:
async def list(self, url: str, body: dict | None = None) -> AsyncIterator[dict]:
"""List all items from a paginated endpoint.

Args:
Expand Down Expand Up @@ -131,7 +130,7 @@ async def main():

@classmethod
def _get_credentials(
cls, credentials: Optional[ClientCredentials]
cls, credentials: ClientCredentials | None
) -> ClientCredentials:
credentials = credentials or ClientCredentials()
credentials.vendor_id = credentials.vendor_id or os.getenv(ARMIS_VENDOR_ID)
Expand Down
9 changes: 3 additions & 6 deletions armis_sdk/core/armis_error.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,13 @@
"""

import json
from typing import List
from typing import Optional
from typing import Union

from httpx import HTTPStatusError
from pydantic import BaseModel


class DetailItem(BaseModel):
loc: list[Union[str, int]]
loc: list[str | int]
msg: str
type: str

Expand All @@ -26,7 +23,7 @@ def __str__(self):


class ErrorBody(BaseModel):
detail: Union[str, List[DetailItem]]
detail: str | list[DetailItem]


class ArmisError(Exception):
Expand Down Expand Up @@ -63,7 +60,7 @@ class ResponseError(ArmisError):
def __init__(
self,
error_body: ErrorBody,
response_errors: Optional[List[HTTPStatusError]] = None,
response_errors: list[HTTPStatusError] | None = None,
):
super().__init__(self._get_message(error_body))
self.response_errors = response_errors
Expand Down
4 changes: 1 addition & 3 deletions armis_sdk/core/armis_sdk.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
from typing import Optional

from armis_sdk.clients.assets_client import AssetsClient
from armis_sdk.clients.collectors_client import CollectorsClient
from armis_sdk.clients.data_export_client import DataExportClient
Expand Down Expand Up @@ -41,7 +39,7 @@ async def main():
```
"""

def __init__(self, credentials: Optional[ClientCredentials] = None):
def __init__(self, credentials: ClientCredentials | None = None):
self.client: ArmisClient = ArmisClient(credentials=credentials)
self.assets: AssetsClient = AssetsClient(self.client)
self.collectors: CollectorsClient = CollectorsClient(self.client)
Expand Down
8 changes: 3 additions & 5 deletions armis_sdk/core/base_entity_client.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,4 @@
from typing import AsyncIterator
from typing import Optional
from typing import Type
from collections.abc import AsyncIterator

import universalasync

Expand All @@ -10,12 +8,12 @@

class BaseEntityClient: # pylint: disable=too-few-public-methods

def __init__(self, armis_client: Optional[ArmisClient] = None) -> None:
def __init__(self, armis_client: ArmisClient | None = None) -> None:
self._armis_client = armis_client or ArmisClient()

@universalasync.async_to_sync_wraps
async def _list(
self, url: str, model: Type[BaseEntityT]
self, url: str, model: type[BaseEntityT]
) -> AsyncIterator[BaseEntityT]:
async for item in self._armis_client.list(url):
yield model.model_validate(item)
11 changes: 5 additions & 6 deletions armis_sdk/core/client_credentials.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
import dataclasses
from typing import Optional


@dataclasses.dataclass
class ClientCredentials:
audience: Optional[str] = None
client_id: Optional[str] = None
client_secret: Optional[str] = None
vendor_id: Optional[str] = None
scopes: Optional[list[str]] = None
audience: str | None = None
client_id: str | None = None
client_secret: str | None = None
vendor_id: str | None = None
scopes: list[str] | None = None
Loading
Loading