Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/Contracts/DecoderInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,14 @@ interface DecoderInterface
* @param string[] $tokens
*/
public function decode(array $tokens): string;

/**
* Get configuration value(s).
*
* @param null|string $key The configuration key. If null, returns all config.
* @param mixed $default The default value if the key doesn't exist
*
* @return mixed the configuration value, or full config array if $key is null
*/
public function getConfig(?string $key = null, mixed $default = null): mixed;
}
10 changes: 6 additions & 4 deletions src/Contracts/ModelInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -54,10 +54,12 @@ public function getVocabSize(): int;
public function addToken(string $token, int $id): void;

/**
* Get the end of word suffix, if any.
* Only some models (like BPE) have this property.
* Get configuration value(s).
*
* @return null|string the end of word suffix
* @param null|string $key The configuration key (e.g., 'dropout'). If null, returns all config.
* @param mixed $default The default value if the key doesn't exist (ignored when $key is null)
*
* @return mixed the configuration value, or full config array if $key is null
*/
public function getEndOfWordSuffix(): ?string;
public function getConfig(?string $key = null, mixed $default = null): mixed;
}
10 changes: 10 additions & 0 deletions src/Contracts/NormalizerInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,14 @@
interface NormalizerInterface
{
public function normalize(string $text): string;

/**
* Get configuration value(s).
*
* @param null|string $key The configuration key. If null, returns all config.
* @param mixed $default The default value if the key doesn't exist
*
* @return mixed the configuration value, or full config array if $key is null
*/
public function getConfig(?string $key = null, mixed $default = null): mixed;
}
10 changes: 10 additions & 0 deletions src/Contracts/PostProcessorInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,14 @@ interface PostProcessorInterface
* @return array{0: string[], 1: int[]} the processed tokens and type IDs
*/
public function process(array $tokens, ?array $pair = null, bool $addSpecialTokens = true): array;

/**
* Get configuration value(s).
*
* @param null|string $key The configuration key. If null, returns all config.
* @param mixed $default The default value if the key doesn't exist
*
* @return mixed the configuration value, or full config array if $key is null
*/
public function getConfig(?string $key = null, mixed $default = null): mixed;
}
10 changes: 10 additions & 0 deletions src/Contracts/PreTokenizerInterface.php
Original file line number Diff line number Diff line change
Expand Up @@ -15,4 +15,14 @@ interface PreTokenizerInterface
* @return string[]
*/
public function preTokenize(array|string $text, array $options = []): array;

/**
* Get configuration value(s).
*
* @param null|string $key The configuration key. If null, returns all config.
* @param mixed $default The default value if the key doesn't exist
*
* @return mixed the configuration value, or full config array if $key is null
*/
public function getConfig(?string $key = null, mixed $default = null): mixed;
}
18 changes: 17 additions & 1 deletion src/DataStructures/AddedToken.php
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
* - Whether they should only match single words
* - Whether to include any whitespace on its left or right.
*/
readonly class AddedToken
class AddedToken implements \JsonSerializable
{
public function __construct(
/**
Expand Down Expand Up @@ -58,4 +58,20 @@ public static function fromArray(array $data): self
$data['special'] ?? false,
);
}

/**
* @return array<string, mixed>
*/
public function jsonSerialize(): array
{
return [
'id' => $this->id,
'content' => $this->content,
'single_word' => $this->singleWord,
'lstrip' => $this->lStrip,
'rstrip' => $this->rStrip,
'normalized' => $this->normalized,
'special' => $this->special,
];
}
}
16 changes: 16 additions & 0 deletions src/Decoders/BPEDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,22 @@ class BPEDecoder extends BaseDecoder
{
public function __construct(protected string $suffix = '') {}

public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return match ($key) {
'type' => 'BPEDecoder',
'suffix' => $this->suffix,
default => $default,
};
}

return [
'type' => 'BPEDecoder',
'suffix' => $this->suffix,
];
}

protected function processTokens(array $tokens): array
{
$decoded = [];
Expand Down
9 changes: 9 additions & 0 deletions src/Decoders/ByteFallbackDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@

class ByteFallbackDecoder extends BaseDecoder
{
public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return 'type' === $key ? 'ByteFallback' : $default;
}

return ['type' => 'ByteFallback'];
}

protected function processTokens(array $tokens): array
{
$newTokens = [];
Expand Down
13 changes: 9 additions & 4 deletions src/Decoders/ByteLevelDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -275,11 +275,16 @@ public function __construct(protected array $addedTokens = [], protected ?string

/**
* Convert an array of tokens to a string by decoding each byte.
*
* @param string[] $tokens array of tokens to be decoded
*
* @return string the decoded string
*/
public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return 'type' === $key ? 'ByteLevel' : $default;
}

return ['type' => 'ByteLevel'];
}

public function decode(array $tokens): string
{
$decoded = parent::decode($tokens);
Expand Down
20 changes: 20 additions & 0 deletions src/Decoders/CTCDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,26 @@ public function __construct(
protected bool $cleanup = true
) {}

public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return match ($key) {
'type' => 'CTC',
'pad_token' => $this->padToken,
'word_delimiter_token' => $this->wordDelimiterToken,
'cleanup' => $this->cleanup,
default => $default,
};
}

return [
'type' => 'CTC',
'pad_token' => $this->padToken,
'word_delimiter_token' => $this->wordDelimiterToken,
'cleanup' => $this->cleanup,
];
}

protected function processTokens(array $tokens): array
{
if (empty($tokens)) {
Expand Down
16 changes: 16 additions & 0 deletions src/Decoders/DecoderSequence.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,22 @@ class DecoderSequence extends BaseDecoder
*/
public function __construct(protected array $decoders) {}

public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return match ($key) {
'type' => 'Sequence',
'decoders' => array_map(static fn (BaseDecoder $d) => $d->getConfig(), $this->decoders),
default => $default,
};
}

return [
'type' => 'Sequence',
'decoders' => array_map(static fn (BaseDecoder $d) => $d->getConfig(), $this->decoders),
];
}

protected function processTokens(array $tokens): array
{
return array_reduce(
Expand Down
16 changes: 16 additions & 0 deletions src/Decoders/FuseDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,22 @@ public function __construct(
protected string $separator = ''
) {}

public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return match ($key) {
'type' => 'Fuse',
'separator' => $this->separator,
default => $default,
};
}

return [
'type' => 'Fuse',
'separator' => $this->separator,
];
}

protected function processTokens(array $tokens): array
{
return [implode($this->separator, $tokens)];
Expand Down
18 changes: 18 additions & 0 deletions src/Decoders/MetaspaceDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,24 @@ public function __construct(
protected bool $addPrefixSpace = true
) {}

public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return match ($key) {
'type' => 'Metaspace',
'replacement' => $this->replacement,
'add_prefix_space' => $this->addPrefixSpace,
default => $default,
};
}

return [
'type' => 'Metaspace',
'replacement' => $this->replacement,
'add_prefix_space' => $this->addPrefixSpace,
];
}

protected function processTokens(array $tokens): array
{
$result = [];
Expand Down
18 changes: 18 additions & 0 deletions src/Decoders/ReplaceDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,24 @@ public function __construct(
protected string $replacement = ''
) {}

public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return match ($key) {
'type' => 'Replace',
'pattern' => $this->regex ? ['Regex' => $this->regex] : ['String' => $this->subString],
'content' => $this->replacement,
default => $default,
};
}

return [
'type' => 'Replace',
'pattern' => $this->regex ? ['Regex' => $this->regex] : ['String' => $this->subString],
'content' => $this->replacement,
];
}

protected function processTokens(array $tokens): array
{
return array_map(function ($token) {
Expand Down
20 changes: 20 additions & 0 deletions src/Decoders/StripDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,26 @@ public function __construct(
protected int $stop
) {}

public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return match ($key) {
'type' => 'Strip',
'content' => $this->content,
'start' => $this->start,
'stop' => $this->stop,
default => $default,
};
}

return [
'type' => 'Strip',
'content' => $this->content,
'start' => $this->start,
'stop' => $this->stop,
];
}

protected function processTokens(array $tokens): array
{
return array_map(function ($token) {
Expand Down
18 changes: 18 additions & 0 deletions src/Decoders/WordPieceDecoder.php
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,24 @@ public function __construct(
protected bool $cleanup = true
) {}

public function getConfig(?string $key = null, mixed $default = null): mixed
{
if (null !== $key) {
return match ($key) {
'type' => 'WordPiece',
'prefix' => $this->prefix,
'cleanup' => $this->cleanup,
default => $default,
};
}

return [
'type' => 'WordPiece',
'prefix' => $this->prefix,
'cleanup' => $this->cleanup,
];
}

protected function processTokens(array $tokens): array
{
$decodedTokens = [];
Expand Down
4 changes: 2 additions & 2 deletions src/Factories/DecoderFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,10 @@ class DecoderFactory
* @param array<string, AddedToken> $addedTokens Optional. Only needed for ByteLevelDecoder.
* @param null|string $endOfWordSuffix Optional. Only needed for ByteLevelDecoder.
*/
public static function create(array $config, array $addedTokens = [], ?string $endOfWordSuffix = null): ?DecoderInterface
public static function create(array $config, array $addedTokens = [], ?string $endOfWordSuffix = null): DecoderInterface
{
if (empty($config)) {
return null;
return new FuseDecoder(' ');
}

$type = $config['type'] ?? null;
Expand Down
5 changes: 5 additions & 0 deletions src/Factories/NormalizerFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
use Codewithkyrian\Tokenizers\Normalizers\NFKCNormalizer;
use Codewithkyrian\Tokenizers\Normalizers\NFKDNormalizer;
use Codewithkyrian\Tokenizers\Normalizers\NormalizerSequence;
use Codewithkyrian\Tokenizers\Normalizers\PassThroughNormalizer;
use Codewithkyrian\Tokenizers\Normalizers\PrecompiledNormalizer;
use Codewithkyrian\Tokenizers\Normalizers\PrependNormalizer;
use Codewithkyrian\Tokenizers\Normalizers\ReplaceNormalizer;
Expand All @@ -24,6 +25,10 @@ class NormalizerFactory
*/
public static function create(array $config): NormalizerInterface
{
if (empty($config)) {
return new PassThroughNormalizer();
}

$type = $config['type'] ?? null;

return match ($type) {
Expand Down
5 changes: 3 additions & 2 deletions src/Factories/PostProcessorFactory.php
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
use Codewithkyrian\Tokenizers\Contracts\PostProcessorInterface;
use Codewithkyrian\Tokenizers\PostProcessors\BertPostProcessor;
use Codewithkyrian\Tokenizers\PostProcessors\ByteLevelPostProcessor;
use Codewithkyrian\Tokenizers\PostProcessors\DefaultPostProcessor;
use Codewithkyrian\Tokenizers\PostProcessors\PostProcessorSequence;
use Codewithkyrian\Tokenizers\PostProcessors\RobertaPostProcessor;
use Codewithkyrian\Tokenizers\PostProcessors\TemplatePostProcessor;
Expand All @@ -16,10 +17,10 @@ class PostProcessorFactory
/**
* @param array<string, mixed> $config the post-processor configuration
*/
public static function create(array $config): ?PostProcessorInterface
public static function create(array $config): PostProcessorInterface
{
if (empty($config)) {
return null;
return new DefaultPostProcessor();
}

$type = $config['type'] ?? null;
Expand Down
Loading