diff --git a/openapi.yaml b/openapi.yaml index 8765caf..68b2b4e 100644 --- a/openapi.yaml +++ b/openapi.yaml @@ -6488,6 +6488,179 @@ paths: } ``` + /rl/sessions/start_session: + post: + operationId: sessionStart + summary: Create training session + x-codeSamples: + - lang: Python + label: Python API Client + source: | + from together import Together + client = Together() + + session = client.beta.rl.sessions.create_session() + print(session) + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/RL.SessionCreateParams' + responses: + '200': + description: Created training session and initial LoRA adapter + content: + application/json: + schema: + $ref: '#/components/schemas/RL.SessionStartResponse' + /rl/retrieve_future: + post: + summary: Retrieve future result + description: Retrieves the result of a future by its ID + operationId: retrieveFuture + x-codeSamples: + - lang: Python + label: Python WebSocket Client + source: | + from together import Together + client = Together() + + future = client.beta.rl.futures.retrieve() + print(future) + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/RL.UntypedAPIFuture' + responses: + '200': + description: Future result + content: + application/json: + schema: + oneOf: + - $ref: '#/components/schemas/RL.TryAgainResponse' + - $ref: '#/components/schemas/RL.ForwardBackwardOutput' + - $ref: '#/components/schemas/RL.OptimStepResponse' + - $ref: '#/components/schemas/RL.SaveTrainingStateResponse' + '400': + description: Bad request (invalid future ID) + content: + application/json: + schema: + $ref: '#/components/schemas/RL.ErrorResponse' + '404': + description: Future not found + content: + application/json: + schema: + $ref: '#/components/schemas/RL.ErrorResponse' + '408': + description: Timeout, try again + content: + application/json: + schema: + $ref: '#/components/schemas/RL.TryAgainResponse' + /rl/forward: + post: + summary: Forward pass + description: Performs a forward pass through the model + operationId: forward + x-codeSamples: + - lang: Python + label: Python API Client + source: | + from together import Together + client = Together() + + future = client.beta.rl.training.forward() + print(future) + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - training_session_id + - forward_input + properties: + training_session_id: + $ref: '#/components/schemas/RL.TrainingId' + forward_input: + $ref: '#/components/schemas/RL.ForwardBackwardInput' + responses: + '200': + description: API future for forward completion + content: + application/json: + schema: + $ref: '#/components/schemas/RL.UntypedAPIFuture' + /rl/forward_backward: + post: + summary: Forward and backward pass + description: Performs a forward and backward pass through the model + operationId: forwardBackward + x-codeSamples: + - lang: Python + label: Python API Client + source: | + from together import Together + client = Together() + + future = client.beta.rl.training.forward_backward() + print(future) + requestBody: + required: true + content: + application/json: + schema: + type: object + required: + - training_session_id + - forward_backward_input + properties: + training_session_id: + $ref: '#/components/schemas/RL.TrainingId' + forward_backward_input: + $ref: '#/components/schemas/RL.ForwardBackwardInput' + responses: + '200': + description: API future for forward/backward completion + content: + application/json: + schema: + $ref: '#/components/schemas/RL.UntypedAPIFuture' + /rl/optim_step: + post: + summary: Optimization step + description: Performs an optimization step using AdamW optimizer + operationId: optimStep + x-codeSamples: + - lang: Python + label: Python API Client + source: | + from together import Together + client = Together() + + future = client.beta.rl.training.optim_step() + print(future) + requestBody: + required: true + content: + application/json: + schema: + $ref: '#/components/schemas/RL.OptimStepRequest' + responses: + '200': + description: API future for optimization step completion + content: + application/json: + schema: + $ref: '#/components/schemas/RL.UntypedAPIFuture' + components: securitySchemes: bearerAuth: @@ -11402,3 +11575,377 @@ components: type: string x-enum-varnames: - VolumeTypeReadOnly + + RL.ModelId: + type: string + format: uuid + description: Identifier of the current version of the model associated with a training session, that can be used for inference. + example: 123e4567-e89b-12d3-a456-426614174000 + RL.TrainingId: + type: string + format: uuid + description: Identifier for a high-level training session that can be resumed or referenced later + example: 123e4567-e89b-12d3-a456-426614174111 + RL.RequestId: + type: string + format: uuid + example: 123e4567-e89b-12d3-a456-426614174999 + RL.ErrorData: + type: object + properties: + error: + type: string + description: Error code + message: + type: string + description: Human-readable error message + details: + type: object + additionalProperties: true + description: Additional details for debugging + required: + - error + - message + RL.ErrorResponse: + type: object + properties: + error: + $ref: '#/components/schemas/RL.ErrorData' + required: + - error + RL.ForwardBackwardInput: + type: object + required: + - data + - loss_fn + properties: + data: + type: array + items: + $ref: '#/components/schemas/RL.Datum' + description: Array of input data for the forward/backward pass + loss_fn: + $ref: '#/components/schemas/RL.LossFnType' + RL.LossFnType: + type: string + description: Fully qualified function path for the loss function + example: grpo + enum: + - grpo + RL.SessionCreateParams: + type: object + description: | + Specify `base_model` with `lora_config` to create a new LoRA adapter. + Optionally provide `checkpoint_id` to initialize from an existing checkpoint. + properties: + base_model: + type: string + lora_config: + $ref: '#/components/schemas/RL.LoraConfigParam' + user_metadata: + type: object + additionalProperties: {} + nullable: true + type: + type: string + enum: [create_model] + checkpoint_id: + type: string + description: Optional checkpoint handle/URI to initialize weights from + required: [base_model, lora_config] + + RL.LoraConfigParam: + type: object + description: LoRA configuration parameters. + properties: + type: + type: string + enum: ['Lora'] + lora_r: + type: integer + lora_alpha: + type: integer + lora_dropout: + type: number + format: float + default: 0.0 + lora_trainable_modules: + type: string + default: 'all-linear' + required: + - type + - lora_r + - lora_alpha + + RL.ModelInput: + type: object + required: + - chunks + properties: + chunks: + type: array + items: + $ref: '#/components/schemas/RL.ModelInputChunk' + description: Sequence of input chunks + + RL.ModelInputChunk: + oneOf: + - $ref: '#/components/schemas/RL.EncodedTextChunk' + - $ref: '#/components/schemas/RL.ImageAssetPointerChunk' + discriminator: + propertyName: type + mapping: + encoded_text: '#/components/schemas/RL.EncodedTextChunk' + image_asset_pointer: '#/components/schemas/RL.ImageAssetPointerChunk' + RL.EncodedTextChunk: + type: object + required: + - type + - tokens + properties: + type: + type: string + enum: + - encoded_text + tokens: + type: array + items: + type: integer + format: int32 + description: Array of token IDs + example: + - 1234 + - 5678 + - 9012 + RL.ImageAssetPointerChunk: + type: object + required: + - type + - location + - format + - width + - height + - tokens + - process_image_function_name + properties: + type: + type: string + enum: + - image_asset_pointer + location: + type: string + description: Path or URL to the image asset + example: /path/to/image.jpg + format: + type: string + enum: + - png + - jpeg + description: Image format + width: + type: integer + format: int32 + description: Image width in pixels + height: + type: integer + format: int32 + description: Image height in pixels + tokens: + type: integer + format: int32 + description: Number of tokens this image represents + process_image_function_name: + type: string + description: Name of the function to process this image + example: process_image_default + + RL.Datum: + type: object + required: + - model_input + - loss_fn_inputs + properties: + model_input: + $ref: '#/components/schemas/RL.ModelInput' + loss_fn_inputs: + $ref: '#/components/schemas/RL.LossFnInputs' + + RL.LossFnInputs: + type: object + additionalProperties: + $ref: '#/components/schemas/RL.TensorData' + description: Dictionary mapping field names to tensor data + example: + weights: + data: + - 1 + - 1 + - 1 + - 0.5 + - 0 + dtype: float32 + target_tokens: + data: + - 123 + - 456 + - 789 + - 101 + - 202 + dtype: int64 + RL.TensorDType: + type: string + enum: + - int64 + - float32 + RL.TensorData: + type: object + required: + - data + - dtype + properties: + data: + description: Flattened tensor data as array of numbers. + oneOf: + - type: array + items: + type: integer + - type: array + items: + type: number + dtype: + $ref: '#/components/schemas/RL.TensorDType' + shape: + type: array + description: >- + Optional. The shape of the tensor (see PyTorch tensor.shape). The shape of a one-dimensional list + of length N is `(N,)`. Can usually be inferred if not provided, and is generally inferred as a 1D + tensor. + example: + - 10 + items: + type: integer + nullable: true + + RL.UntypedAPIFuture: + type: object + required: + - request_id + - training_session_id + properties: + request_id: + $ref: '#/components/schemas/RL.RequestId' + training_session_id: + $ref: '#/components/schemas/RL.TrainingId' + RL.OptimStepRequest: + type: object + required: + - training_session_id + - adam_params + properties: + type: + type: string + enum: + - optim_step + training_session_id: + $ref: '#/components/schemas/RL.TrainingId' + adam_params: + $ref: '#/components/schemas/RL.RL.AdamParams' + RL.RL.AdamParams: + type: object + required: + - learning_rate + properties: + learning_rate: + type: number + format: float + description: Learning rate for the optimizer + default: 0.0001 + x-stainless-useDefault: true + beta1: + type: number + format: float + description: Coefficient used for computing running averages of gradient + default: 0.9 + x-stainless-useDefault: true + beta2: + type: number + format: float + description: Coefficient used for computing running averages of gradient square + default: 0.95 + x-stainless-useDefault: true + eps: + type: number + format: float + description: Term added to the denominator to improve numerical stability + default: 1e-12 + x-stainless-useDefault: true + RL.ForwardBackwardOutput: + type: object + required: + - metrics + - loss_fn_outputs + - loss_fn_output_type + properties: + metrics: + type: object + additionalProperties: + oneOf: + - type: number + - type: integer + description: Training metrics as key-value pairs + example: + loss: 2.345 + loss_fn_output_type: + $ref: '#/components/schemas/RL.LossFnType' + description: Identifier of the loss function used to compute outputs + loss_fn_outputs: + type: object + additionalProperties: + $ref: '#/components/schemas/RL.TensorData' + description: Dictionary mapping field names to tensor data + RL.SessionStartResponse: + type: object + properties: + training_session_id: + $ref: '#/components/schemas/RL.TrainingId' + model_id: + type: string + $ref: '#/components/schemas/RL.ModelId' + required: [training_session_id, model_id] + RL.TryAgainResponse: + type: object + description: This response is sent when the client has waited too long and the request is still pending. Client should retry retrieve_future. + required: + - type + - request_id + properties: + type: + type: string + enum: + - try_again + default: try_again + request_id: + type: string + description: Request ID that is still pending + RL.OptimStepResponse: + type: object + additionalProperties: + oneOf: + - type: integer + - type: number + - type: string + description: Optimization step results with arbitrary key-value pairs + example: + step: 1000 + learning_rate: 0.00001 + gradient_norm: 0.456 + status: completed + RL.SaveTrainingStateResponse: + type: object + required: + - path + properties: + path: + type: string + description: A handle or URI to the saved training state (weights + optimizer) + example: s3://bucket/checkpoints/session-12345/state-000042.pt