diff --git a/openapi.yaml b/openapi.yaml index b066537..eecfa30 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -7171,6 +7171,38 @@ paths: required: true schema: type: string + /rl/training-sessions/{session_id}/operations/sample/{operation_id}: + get: + summary: Get sample operation + description: Retrieves the current status and result of a sample operation. + operationId: GetSample + tags: [RL] + responses: + "200": + description: "" + content: + application/json: + schema: + $ref: '#/components/schemas/RL.SampleOperation' + default: + description: An unexpected error response. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + parameters: + - name: session_id + description: Training session ID + in: path + required: true + schema: + type: string + - name: operation_id + description: Operation ID + in: path + required: true + schema: + type: string /rl/training-sessions/{session_id}:forward-backward: post: summary: Forward-backward pass @@ -7235,6 +7267,38 @@ paths: required: true schema: type: string + /rl/training-sessions/{session_id}:sample: + post: + summary: Sample + description: Submits a sample operation that will asynchronously generate text completions with logprobs. + operationId: Sample + tags: [RL] + requestBody: + content: + application/json: + schema: + $ref: '#/components/schemas/RL.SampleBody' + required: true + responses: + "200": + description: "" + content: + application/json: + schema: + $ref: '#/components/schemas/RL.SampleOperation' + default: + description: An unexpected error response. + content: + application/json: + schema: + $ref: '#/components/schemas/ErrorResponse' + parameters: + - name: session_id + description: Training session ID + in: path + required: true + schema: + type: string /rl/training-sessions/{session_id}:stop: post: summary: Stop training session @@ -7280,11 +7344,27 @@ components: default: 0.0001 adamw_params: $ref: '#/components/schemas/RL.AdamWOptimizerParams' + RL.SampleBody: + type: object + required: + - prompt + properties: + prompt: + $ref: '#/components/schemas/RL.ModelInput' + description: Input prompt as tokenized chunks + sampling_params: + $ref: '#/components/schemas/RL.SamplingParams' + description: Optional sampling parameters + num_samples: + type: integer + example: 1 + default: 1 + description: Number of completions to generate for this prompt RL.ForwardBackwardBody: type: object required: - samples - - loss_fn + - loss properties: samples: description: Batch of training samples to process @@ -7292,10 +7372,9 @@ components: items: type: object $ref: '#/components/schemas/RL.TrainingSample' - loss_fn: - description: Loss function to use for gradient computation - $ref: '#/components/schemas/RL.LossFn' - example: LOSS_FN_GRPO + loss: + description: Loss function configuration + $ref: '#/components/schemas/RL.LossConfig' RL.ModelInput: type: object required: @@ -7307,6 +7386,41 @@ components: items: type: object $ref: '#/components/schemas/RL.InputChunk' + RL.SamplingParams: + type: object + properties: + max_tokens: + type: integer + example: 100 + default: 100 + description: Maximum number of tokens to generate per completion + temperature: + type: number + example: 1 + default: 1.0 + description: Sampling temperature + top_p: + type: number + example: 1 + default: 1.0 + description: Nucleus sampling probability threshold + top_k: + type: integer + example: -1 + default: -1 + description: Top-k sampling limit + stop: + type: array + example: + - "\n" + - END + items: + type: string + description: Generation stops when any of these strings is produced + seed: + type: integer + example: 42 + description: Random seed for reproducibility RL.InputChunk: type: object properties: @@ -7332,25 +7446,50 @@ components: - D_TYPE_FLOAT32 - D_TYPE_BFLOAT16 default: D_TYPE_UNSPECIFIED - RL.LossFn: + RL.LossType: type: string - default: LOSS_FN_UNSPECIFIED enum: - - LOSS_FN_UNSPECIFIED - - LOSS_FN_GRPO - RL.LossFnInputs: + - LOSS_TYPE_UNSPECIFIED + - LOSS_TYPE_CROSS_ENTROPY + - LOSS_TYPE_GRPO + default: LOSS_TYPE_UNSPECIFIED + RL.CrossEntropyLossParams: + type: object + description: Cross-entropy loss parameters (currently empty). + properties: {} + RL.GRPOLossParams: + type: object + properties: + clip_low: + type: number + example: 0.2 + description: Lower clip bound for importance ratio + clip_high: + type: number + example: 0.28 + description: Upper clip bound for importance ratio + beta: + type: number + example: 0.1 + description: KL penalty coefficient + agg_type: + type: string + example: fixed_horizon + description: Aggregation type for loss computation + RL.LossConfig: type: object required: - - weights - - target_tokens + - type properties: - weights: - $ref: '#/components/schemas/RL.LossFnWeights' - description: Per-token weights - target_tokens: - $ref: '#/components/schemas/RL.LossFnTargetTokens' - description: Target tokens for loss computation - RL.LossFnTargetTokens: + type: + $ref: '#/components/schemas/RL.LossType' + description: Type of loss function to use + example: LOSS_TYPE_GRPO + cross_entropy_params: + $ref: '#/components/schemas/RL.CrossEntropyLossParams' + grpo_params: + $ref: '#/components/schemas/RL.GRPOLossParams' + RL.LossTargetTokens: type: object required: - data @@ -7363,42 +7502,102 @@ components: - 456 - 789 items: - type: number + type: integer dtype: description: Data type of the integer array $ref: '#/components/schemas/RL.DType' example: D_TYPE_INT64 - RL.LossFnWeights: + RL.LossMask: type: object required: - data + description: Per-token loss mask (1=compute loss, 0=ignore) properties: data: - description: Float array of per-token weights + description: Integer array of per-token mask values (0s and 1s) type: array example: - - 0.1 - - 0.2 - - 0.3 + - 0 + - 0 + - 1 + items: + type: integer + dtype: + description: Data type of the integer array (must be D_TYPE_INT64) + $ref: '#/components/schemas/RL.DType' + example: D_TYPE_INT64 + RL.LossAdvantages: + type: object + required: + - data + properties: + data: + description: Float array of per-token advantages + type: array + example: + - 0.5 + - 0.5 + items: + type: number + dtype: + description: Data type of the float array (D_TYPE_FLOAT32 or D_TYPE_BFLOAT16) + $ref: '#/components/schemas/RL.DType' + example: D_TYPE_FLOAT32 + RL.LossLogprobs: + type: object + required: + - data + properties: + data: + description: Float array of per-token log probabilities + type: array + example: + - -1.2 + - -0.8 items: type: number - format: float dtype: - description: Data type of the float array + description: Data type of the float array (D_TYPE_FLOAT32 or D_TYPE_BFLOAT16) $ref: '#/components/schemas/RL.DType' example: D_TYPE_FLOAT32 + RL.GRPOLossInputs: + type: object + required: + - advantages + - generator_logprobs + properties: + advantages: + $ref: '#/components/schemas/RL.LossAdvantages' + description: Per-token advantages for GRPO + generator_logprobs: + $ref: '#/components/schemas/RL.LossLogprobs' + description: Generator log probabilities for GRPO + reference_logprobs: + $ref: '#/components/schemas/RL.LossLogprobs' + description: Reference model log probabilities (required if beta > 0) + RL.LossInputs: + type: object + properties: + target_tokens: + $ref: '#/components/schemas/RL.LossTargetTokens' + description: Target tokens for loss computation (optional, defaults to shifted input_ids) + loss_mask: + $ref: '#/components/schemas/RL.LossMask' + description: Per-token loss mask (1=compute loss, 0=ignore) + grpo_inputs: + $ref: '#/components/schemas/RL.GRPOLossInputs' RL.TrainingSample: type: object required: - model_input - - loss_fn_inputs + - loss_inputs properties: model_input: description: Model input $ref: '#/components/schemas/RL.ModelInput' - loss_fn_inputs: + loss_inputs: description: Loss function inputs - $ref: '#/components/schemas/RL.LossFnInputs' + $ref: '#/components/schemas/RL.LossInputs' RL.OptimStepOperation: type: object properties: @@ -7413,6 +7612,21 @@ components: $ref: '#/components/schemas/RL.OptimStepResult' error: $ref: '#/components/schemas/RL.TrainingOperationError' + RL.SampleOperation: + type: object + properties: + operation_id: + type: string + example: 550e8400-e29b-41d4-a716-446655440000 + description: Operation ID + status: + $ref: '#/components/schemas/RL.TrainingOperationStatus' + example: TRAINING_OPERATION_STATUS_PENDING + description: Operation status + data: + $ref: '#/components/schemas/RL.SampleResult' + error: + $ref: '#/components/schemas/RL.TrainingOperationError' RL.OptimStepResult: type: object properties: @@ -7420,6 +7634,40 @@ components: description: Step number type: integer example: 100 + RL.SampleResult: + type: object + properties: + sequences: + type: array + items: + type: object + $ref: '#/components/schemas/RL.SampleSequence' + description: Generated completions + RL.SampleSequence: + type: object + properties: + tokens: + type: array + example: + - 123 + - 456 + - 789 + items: + type: integer + description: Generated token IDs + logprobs: + type: array + example: + - -0.5 + - -1.2 + - -0.3 + items: + type: integer + description: Log probabilities for each generated token + stop_reason: + type: string + example: length + description: Reason for stopping generation RL.ForwardBackwardOperation: type: object properties: @@ -7441,6 +7689,11 @@ components: type: number example: 2.345 description: Loss value + metrics: + type: object + description: Loss-specific metrics (e.g., KL divergence, clip fraction for GRPO) + additionalProperties: + type: number RL.TrainingOperationError: type: object properties: