Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
924 changes: 924 additions & 0 deletions notebooks/how_to/qualitative_text/qualitative_text_generation.ipynb

Large diffs are not rendered by default.

76 changes: 75 additions & 1 deletion tests/test_api_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,11 @@
import validmind.api_client as api_client
from validmind.__version__ import __version__
from validmind.errors import (
APIRequestError,
MissingAPICredentialsError,
MissingModelIdError,
APIRequestError,
)
from validmind.utils import md_to_html
from validmind.vm_models.figure import Figure


Expand Down Expand Up @@ -245,6 +246,79 @@ def test_log_test_result(self, mock_post):

mock_post.assert_called_with(url, data=json.dumps(result))

@patch("requests.post")
@patch("aiohttp.ClientSession.post")
def test_log_text_generates_text_and_logs_metadata(
self, mock_aiohttp_post, mock_requests_post
):
mock_requests_post.return_value = Mock(status_code=200)
mock_requests_post.return_value.json.return_value = {
"content": "## Generated Summary\nGenerated content."
}
mock_aiohttp_post.return_value = MockAsyncResponse(
200,
json={
"content_id": "dataset_summary_text",
"text": md_to_html("## Generated Summary\nGenerated content.", mathml=True),
},
)

api_client.log_text(
content_id="dataset_summary_text",
prompt="Summarize the dataset.",
context={"content_ids": ["train_dataset", "target_description_text"]},
)

mock_requests_post.assert_called_once_with(
url=f"{os.environ['VM_API_HOST']}/ai/generate/qualitative_text_generation",
headers={
"X-API-KEY": os.environ["VM_API_KEY"],
"X-API-SECRET": os.environ["VM_API_SECRET"],
"X-MODEL-CUID": os.environ["VM_API_MODEL"],
"X-MONITORING": "False",
"X-LIBRARY-VERSION": __version__,
},
json={
"content_id": "dataset_summary_text",
"generate": True,
"prompt": "Summarize the dataset.",
"context": {
"content_ids": ["train_dataset", "target_description_text"]
},
},
)
mock_aiohttp_post.assert_called_once_with(
f"{os.environ['VM_API_HOST']}/log_metadata",
data=json.dumps(
{
"content_id": "dataset_summary_text",
"text": md_to_html(
"## Generated Summary\nGenerated content.", mathml=True
),
}
),
)

def test_log_text_rejects_prompt_when_text_is_provided(self):
with self.assertRaisesRegex(
ValueError, "`prompt` is only supported when `text` is omitted"
):
api_client.log_text(
content_id="dataset_summary_text",
text="Hello world",
prompt="Ignore the provided text.",
)

def test_log_text_rejects_invalid_context(self):
with self.assertRaisesRegex(
ValueError,
"`context\\['content_ids'\\]` must contain only non-empty strings",
):
api_client.log_text(
content_id="dataset_summary_text",
context={"content_ids": ["valid", ""]},
)


if __name__ == "__main__":
unittest.main()
71 changes: 70 additions & 1 deletion tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,12 +12,15 @@

import validmind
from validmind import (
get_content_ids,
init_dataset,
init_model,
get_test_suite,
run_text_generation,
run_documentation_tests,
)
from validmind.errors import UnsupportedModelError
from validmind.vm_models.result import TextGenerationResult


@dataclass
Expand Down Expand Up @@ -115,7 +118,7 @@ def test_init_model_invalid_metadata_dict(self):
"key": "value",
"foo": "bar",
}
with self.assertRaises(UnsupportedModelError) as context:
with self.assertRaises(UnsupportedModelError):
init_model(attributes=metadata, __log=False)

def test_init_model_metadata_dict(self):
Expand Down Expand Up @@ -163,6 +166,72 @@ def test_get_default_config(self):
self.assertIn("params", config)


class TestGetContentIds(TestCase):
@mock.patch(
"validmind.client_config.client_config.documentation_template",
MockedConfig.documentation_template,
)
def test_get_all_content_ids(self):
content_ids = get_content_ids()
self.assertEqual(
content_ids,
[
"validmind.data_validation.ClassImbalance",
"validmind.data_validation.DatasetSplit",
],
)

@mock.patch(
"validmind.client_config.client_config.documentation_template",
MockedConfig.documentation_template,
)
def test_get_content_ids_for_single_section(self):
content_ids = get_content_ids("test_section_1")
self.assertEqual(content_ids, ["validmind.data_validation.ClassImbalance"])

@mock.patch(
"validmind.client_config.client_config.documentation_template",
MockedConfig.documentation_template,
)
def test_get_content_ids_for_multiple_sections(self):
content_ids = get_content_ids(["test_section_1", "test_section_2"])
self.assertEqual(
content_ids,
[
"validmind.data_validation.ClassImbalance",
"validmind.data_validation.DatasetSplit",
],
)


class TestRunTextGeneration(TestCase):
@mock.patch(
"validmind.client.api_client._generate_log_text",
return_value="<p>Generated text</p>",
)
def test_run_text_generation(self, mock_generate_text):
result = run_text_generation(
content_id="dataset_summary_text",
prompt="Summarize the dataset.",
context={"content_ids": ["train_dataset"]},
show=False,
)

self.assertIsInstance(result, TextGenerationResult)
self.assertEqual(result.content_id, "dataset_summary_text")
self.assertEqual(result.prompt, "Summarize the dataset.")
self.assertEqual(result.context, {"content_ids": ["train_dataset"]})
self.assertEqual(result.description, "<p>Generated text</p>")
self.assertIn("validmind", result.metadata)
self.assertIn("timestamp", result.metadata)
self.assertIn("duration_seconds", result.metadata)
mock_generate_text.assert_called_once_with(
"dataset_summary_text",
"Summarize the dataset.",
{"content_ids": ["train_dataset"]},
)


# TODO: Fix this test
# class TestPreviewTemplate(TestCase):
# @mock.patch(
Expand Down
30 changes: 30 additions & 0 deletions tests/test_results.py
Original file line number Diff line number Diff line change
Expand Up @@ -181,11 +181,41 @@ def test_text_generation_result(self):
self.assertEqual(text_result.name, "Text Generation Result")
self.assertEqual(text_result.title, "Text Test")
self.assertEqual(text_result.description, "Generated text")
self.assertIsNone(text_result.doc)
self.assertIsNone(text_result.test_name)

html = text_result.to_html()
self.assertIsInstance(html, str)
self.assertIn("Generated text", html)

@patch("validmind.vm_models.result.result.api_client.alog_text")
async def test_text_generation_result_log_async(self, mock_log_text):
"""Test async logging of TextGenerationResult through alog_text"""
text_result = TextGenerationResult(
result_id="text_1",
content_id="dataset_summary_text",
description="Generated text",
)

await text_result.log_async()

mock_log_text.assert_called_once_with(
content_id="dataset_summary_text",
text="Generated text",
)

async def test_text_generation_result_log_async_requires_content_id(self):
"""Test TextGenerationResult requires a content_id when logging"""
text_result = TextGenerationResult(
result_id="text_1",
description="Generated text",
)

with self.assertRaisesRegex(
ValueError, "`content_id` must be provided to log generated text"
):
await text_result.log_async()

def test_validate_log_config(self):
"""Test validation of log configuration"""
test_result = TestResult(result_id="test_1")
Expand Down
4 changes: 4 additions & 0 deletions validmind/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,13 +52,15 @@
from .__version__ import __version__ # noqa: E402
from .api_client import init, log_metric, log_test_result, log_text, reload
from .client import ( # noqa: E402
get_content_ids,
get_test_suite,
init_dataset,
init_model,
init_r_model,
preview_template,
run_documentation_tests,
run_test_suite,
run_text_generation,
)
from .experimental import agents as experimental_agent
from .tests.decorator import scorer as scorer_decorator
Expand Down Expand Up @@ -117,10 +119,12 @@ def check_version():
"init_model",
"init_r_model",
"get_test_suite",
"get_content_ids",
"log_metric",
"preview_template",
"print_env",
"reload",
"run_text_generation",
"run_documentation_tests",
# log metric function (for direct/bulk/retroactive logging of metrics)
# test suite functions (less common)
Expand Down
Loading
Loading