From de99e9d35f8122bedfc8c4d9fc1eb0d9e6dfbda8 Mon Sep 17 00:00:00 2001 From: Soroush Bassam Date: Mon, 9 Feb 2026 22:13:20 -0800 Subject: [PATCH 1/4] fix: align RL forward-backward loss --- openapi.yaml | 164 ++++++++++++++++++++++++++++++++++++++++++++++++--- 1 file changed, 156 insertions(+), 8 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index b066537..64ee8ea 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -7284,7 +7284,7 @@ components: type: object required: - samples - - loss_fn + - loss properties: samples: description: Batch of training samples to process @@ -7292,10 +7292,9 @@ components: items: type: object $ref: '#/components/schemas/RL.TrainingSample' - loss_fn: - description: Loss function to use for gradient computation - $ref: '#/components/schemas/RL.LossFn' - example: LOSS_FN_GRPO + loss: + description: Loss function configuration + $ref: '#/components/schemas/RL.LossConfig' RL.ModelInput: type: object required: @@ -7332,6 +7331,150 @@ components: - D_TYPE_FLOAT32 - D_TYPE_BFLOAT16 default: D_TYPE_UNSPECIFIED + RL.LossType: + type: string + enum: + - LOSS_TYPE_UNSPECIFIED + - LOSS_TYPE_CROSS_ENTROPY + - LOSS_TYPE_GRPO + default: LOSS_TYPE_UNSPECIFIED + RL.CrossEntropyLossParams: + type: object + description: Cross-entropy loss parameters (currently empty). 
+ RL.GRPOLossParams: + type: object + properties: + clip_low: + type: number + format: float + example: 0.2 + description: Lower clip bound for importance ratio + clip_high: + type: number + format: float + example: 0.28 + description: Upper clip bound for importance ratio + beta: + type: number + format: float + example: 0.1 + description: KL penalty coefficient + agg_type: + type: string + example: fixed_horizon + description: Aggregation type for loss computation + RL.LossConfig: + type: object + required: + - type + properties: + type: + $ref: '#/components/schemas/RL.LossType' + description: Type of loss function to use + example: LOSS_TYPE_GRPO + cross_entropy_params: + $ref: '#/components/schemas/RL.CrossEntropyLossParams' + grpo_params: + $ref: '#/components/schemas/RL.GRPOLossParams' + RL.LossTargetTokens: + type: object + required: + - data + properties: + data: + description: Integer array of target tokens + type: array + example: + - 123 + - 456 + - 789 + items: + type: integer + dtype: + description: Data type of the integer array + $ref: '#/components/schemas/RL.DType' + example: D_TYPE_INT64 + RL.LossMask: + type: object + required: + - data + description: Per-token loss mask (1=compute loss, 0=ignore) + properties: + data: + description: Integer array of per-token mask values (0s and 1s) + type: array + example: + - 0 + - 0 + - 1 + items: + type: integer + dtype: + description: Data type of the integer array (must be D_TYPE_INT64) + $ref: '#/components/schemas/RL.DType' + example: D_TYPE_INT64 + RL.LossAdvantages: + type: object + required: + - data + properties: + data: + description: Float array of per-token advantages + type: array + example: + - 0.5 + - 0.5 + items: + type: number + format: float + dtype: + description: Data type of the float array (D_TYPE_FLOAT32 or D_TYPE_BFLOAT16) + $ref: '#/components/schemas/RL.DType' + example: D_TYPE_FLOAT32 + RL.LossLogprobs: + type: object + required: + - data + properties: + data: + description: Float 
array of per-token log probabilities + type: array + example: + - -1.2 + - -0.8 + items: + type: number + format: float + dtype: + description: Data type of the float array (D_TYPE_FLOAT32 or D_TYPE_BFLOAT16) + $ref: '#/components/schemas/RL.DType' + example: D_TYPE_FLOAT32 + RL.GRPOLossInputs: + type: object + required: + - advantages + - generator_logprobs + properties: + advantages: + $ref: '#/components/schemas/RL.LossAdvantages' + description: Per-token advantages for GRPO + generator_logprobs: + $ref: '#/components/schemas/RL.LossLogprobs' + description: Generator log probabilities for GRPO + reference_logprobs: + $ref: '#/components/schemas/RL.LossLogprobs' + description: Reference model log probabilities (required if beta > 0) + RL.LossInputs: + type: object + properties: + target_tokens: + $ref: '#/components/schemas/RL.LossTargetTokens' + description: Target tokens for loss computation (optional, defaults to shifted input_ids) + loss_mask: + $ref: '#/components/schemas/RL.LossMask' + description: Per-token loss mask (1=compute loss, 0=ignore) + grpo_inputs: + $ref: '#/components/schemas/RL.GRPOLossInputs' RL.LossFn: type: string default: LOSS_FN_UNSPECIFIED @@ -7391,14 +7534,14 @@ components: type: object required: - model_input - - loss_fn_inputs + - loss_inputs properties: model_input: description: Model input $ref: '#/components/schemas/RL.ModelInput' - loss_fn_inputs: + loss_inputs: description: Loss function inputs - $ref: '#/components/schemas/RL.LossFnInputs' + $ref: '#/components/schemas/RL.LossInputs' RL.OptimStepOperation: type: object properties: @@ -7441,6 +7584,11 @@ components: type: number example: 2.345 description: Loss value + metrics: + type: object + description: Loss-specific metrics (e.g., KL divergence, clip fraction for GRPO) + additionalProperties: + type: number RL.TrainingOperationError: type: object properties: From 2c15da57d9f1e421916accf5e700daa57e13b40d Mon Sep 17 00:00:00 2001 From: Soroush Bassam Date: Tue, 10 Feb 2026 
14:53:54 -0800 Subject: [PATCH 2/4] add sampling --- openapi.yaml | 172 +++++++++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 172 insertions(+) diff --git a/openapi.yaml b/openapi.yaml index 64ee8ea..e85a0b7 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -7171,6 +7171,38 @@ paths: required: true schema: type: string + /rl/training-sessions/{session_id}/operations/sample/{operation_id}: + get: + summary: Get sample operation + description: Retrieves the current status and result of a sample operation. + operationId: GetSample + tags: [RL] + responses: + "200": + description: "" + content: + application/json: + schema: + $ref: '#/components/schemas/RL.SampleOperation' + default: + description: An unexpected error response. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + parameters: + - name: session_id + description: Training session ID + in: path + required: true + schema: + type: string + - name: operation_id + description: Operation ID + in: path + required: true + schema: + type: string /rl/training-sessions/{session_id}:forward-backward: post: summary: Forward-backward pass @@ -7235,6 +7267,38 @@ paths: required: true schema: type: string + /rl/training-sessions/{session_id}:sample: + post: + summary: Sample + description: Submits a sample operation that will asynchronously generate text completions with logprobs. + operationId: Sample + tags: [RL] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RL.SampleBody' + required: true + responses: + "200": + description: "" + content: + application/json: + schema: + $ref: '#/components/schemas/RL.SampleOperation' + default: + description: An unexpected error response. 
+ content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + parameters: + - name: session_id + description: Training session ID + in: path + required: true + schema: + type: string /rl/training-sessions/{session_id}:stop: post: summary: Stop training session @@ -7280,6 +7344,23 @@ components: default: 0.0001 adamw_params: $ref: '#/components/schemas/RL.AdamWOptimizerParams' + RL.SampleBody: + type: object + required: + - prompt + properties: + prompt: + $ref: '#/components/schemas/RL.ModelInput' + description: Input prompt as tokenized chunks + sampling_params: + $ref: '#/components/schemas/RL.SamplingParams' + description: Optional sampling parameters + num_samples: + type: integer + format: int64 + example: 1 + default: 1 + description: Number of completions to generate for this prompt RL.ForwardBackwardBody: type: object required: @@ -7306,6 +7387,46 @@ components: items: type: object $ref: '#/components/schemas/RL.InputChunk' + RL.SamplingParams: + type: object + properties: + max_tokens: + type: integer + format: int32 + example: 100 + default: 100 + description: Maximum number of tokens to generate per completion + temperature: + type: number + format: float + example: 1 + default: 1.0 + description: Sampling temperature + top_p: + type: number + format: float + example: 1 + default: 1.0 + description: Nucleus sampling probability threshold + top_k: + type: integer + format: int32 + example: -1 + default: -1 + description: Top-k sampling limit + stop: + type: array + example: + - "\n" + - END + items: + type: string + description: Generation stops when any of these strings is produced + seed: + type: string + format: int64 + example: 42 + description: Random seed for reproducibility RL.InputChunk: type: object properties: @@ -7556,6 +7677,21 @@ components: $ref: '#/components/schemas/RL.OptimStepResult' error: $ref: '#/components/schemas/RL.TrainingOperationError' + RL.SampleOperation: + type: object + properties: + operation_id: + 
type: string + example: 550e8400-e29b-41d4-a716-446655440000 + description: Operation ID + status: + $ref: '#/components/schemas/RL.TrainingOperationStatus' + example: TRAINING_OPERATION_STATUS_PENDING + description: Operation status + data: + $ref: '#/components/schemas/RL.SampleResult' + error: + $ref: '#/components/schemas/RL.TrainingOperationError' RL.OptimStepResult: type: object properties: @@ -7563,6 +7699,42 @@ components: description: Step number type: integer example: 100 + RL.SampleResult: + type: object + properties: + sequences: + type: array + items: + type: object + $ref: '#/components/schemas/RL.SampleSequence' + description: Generated completions + RL.SampleSequence: + type: object + properties: + tokens: + type: array + example: + - 123 + - 456 + - 789 + items: + type: string + format: int64 + description: Generated token IDs + logprobs: + type: array + example: + - -0.5 + - -1.2 + - -0.3 + items: + type: number + format: double + description: Log probabilities for each generated token + stop_reason: + type: string + example: length + description: Reason for stopping generation RL.ForwardBackwardOperation: type: object properties: From 84596d29ef3f11919d35b880116bd0234540ed0a Mon Sep 17 00:00:00 2001 From: Blaine Kasten Date: Tue, 10 Feb 2026 17:19:44 -0600 Subject: [PATCH 3/4] some fixes --- openapi.yaml | 62 +++------------------------------------------------- 1 file changed, 3 insertions(+), 59 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index e85a0b7..aee5747 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -7462,6 +7462,7 @@ components: RL.CrossEntropyLossParams: type: object description: Cross-entropy loss parameters (currently empty). 
+ properties: {} RL.GRPOLossParams: type: object properties: @@ -7596,61 +7597,6 @@ components: description: Per-token loss mask (1=compute loss, 0=ignore) grpo_inputs: $ref: '#/components/schemas/RL.GRPOLossInputs' - RL.LossFn: - type: string - default: LOSS_FN_UNSPECIFIED - enum: - - LOSS_FN_UNSPECIFIED - - LOSS_FN_GRPO - RL.LossFnInputs: - type: object - required: - - weights - - target_tokens - properties: - weights: - $ref: '#/components/schemas/RL.LossFnWeights' - description: Per-token weights - target_tokens: - $ref: '#/components/schemas/RL.LossFnTargetTokens' - description: Target tokens for loss computation - RL.LossFnTargetTokens: - type: object - required: - - data - properties: - data: - description: Integer array of target tokens - type: array - example: - - 123 - - 456 - - 789 - items: - type: number - dtype: - description: Data type of the integer array - $ref: '#/components/schemas/RL.DType' - example: D_TYPE_INT64 - RL.LossFnWeights: - type: object - required: - - data - properties: - data: - description: Float array of per-token weights - type: array - example: - - 0.1 - - 0.2 - - 0.3 - items: - type: number - format: float - dtype: - description: Data type of the float array - $ref: '#/components/schemas/RL.DType' - example: D_TYPE_FLOAT32 RL.TrainingSample: type: object required: @@ -7718,8 +7664,7 @@ components: - 456 - 789 items: - type: string - format: int64 + type: integer description: Generated token IDs logprobs: type: array @@ -7728,8 +7673,7 @@ components: - -1.2 - -0.3 items: - type: number - format: double + type: number description: Log probabilities for each generated token stop_reason: type: string From dfc5517e4f589cdd3bd363594a3b68d4f9b69b19 Mon Sep 17 00:00:00 2001 From: Blaine Kasten Date: Tue, 10 Feb 2026 17:22:31 -0600 Subject: [PATCH 4/4] fixes --- openapi.yaml | 13 +------------ 1 file changed, 1 insertion(+), 12 deletions(-) diff --git a/openapi.yaml b/openapi.yaml index aee5747..eecfa30 100644 --- a/openapi.yaml +++ 
b/openapi.yaml @@ -7357,7 +7357,6 @@ components: description: Optional sampling parameters num_samples: type: integer - format: int64 example: 1 default: 1 description: Number of completions to generate for this prompt @@ -7392,25 +7391,21 @@ components: properties: max_tokens: type: integer - format: int32 example: 100 default: 100 description: Maximum number of tokens to generate per completion temperature: type: number - format: float example: 1 default: 1.0 description: Sampling temperature top_p: type: number - format: float example: 1 default: 1.0 description: Nucleus sampling probability threshold top_k: type: integer - format: int32 example: -1 default: -1 description: Top-k sampling limit @@ -7423,8 +7418,7 @@ components: type: string description: Generation stops when any of these strings is produced seed: - type: string - format: int64 + type: integer example: 42 description: Random seed for reproducibility RL.InputChunk: @@ -7468,17 +7462,14 @@ components: properties: clip_low: type: number - format: float example: 0.2 description: Lower clip bound for importance ratio clip_high: type: number - format: float example: 0.28 description: Upper clip bound for importance ratio beta: type: number - format: float example: 0.1 description: KL penalty coefficient agg_type: @@ -7548,7 +7539,6 @@ components: - 0.5 items: type: number - format: float dtype: description: Data type of the float array (D_TYPE_FLOAT32 or D_TYPE_BFLOAT16) $ref: '#/components/schemas/RL.DType' @@ -7566,7 +7556,6 @@ components: - -0.8 items: type: number - format: float dtype: description: Data type of the float array (D_TYPE_FLOAT32 or D_TYPE_BFLOAT16) $ref: '#/components/schemas/RL.DType'