Model Fine-Tuning API
This documentation provides a comprehensive guide to the Model Fine-Tuning API within AI Refinery. The API enables you to customize AI models with your own data according to your specific requirements. You can access this functionality through our SDK using either the AIRefinery or AsyncAIRefinery clients.
The Fine-Tuning API allows you to:
- Create custom models tailored to your specific use cases
- Improve model performance on domain-specific tasks
- Adapt pre-trained models from Hugging Face to your particular data distribution
Asynchronous Fine-Tuning API
Fine-Tuning Job Creation
The AsyncAIRefinery client allows you to submit a fine-tuning job to the computing cluster asynchronously by sending a POST request to the endpoint.
AsyncAIRefinery.fine_tuning.jobs.create()
Parameters:
- job_config (dict[str, Any] | FineTuningJobConfig): The job description and hyper-parameters for the fine-tuning process. See the template configuration in YAML format below.
- uuid (str): The unique identifier assigned to the user.
- timeout (float | None): Maximum time (in seconds) to wait for a response. Defaults to 60 seconds if not provided.
- extra_headers (dict[str, str] | None): Request-specific headers that override any default headers.
- **kwargs: Additional parameters.
job_config Parameters:
- description (str): A brief description of the fine-tuning job.
- method (str): The fine-tuning method to use (currently only "supervised" is supported).
- train_config (dict[str, Any]):
  - hf_api_key (str): Your Hugging Face API key for accessing model repositories.
  - model (str): The base Hugging Face model to fine-tune (e.g., "NousResearch/Meta-Llama-3.1-8B-Instruct").
  - max_epochs (int): Maximum number of fine-tuning epochs.
  - lr (float): Learning rate for the optimization process.
  - training_file (str): Name of the training dataset stored in the cloud.
  - validation_file (str): Name of the validation dataset stored in the cloud.
  - chat (bool): Whether the training data is a multi-turn chat dataset.
  - global_batch_size (int): Total batch size across all devices.
  - max_seq_length (int): Maximum sequence length for input tokens.
Template of job config fine_tuning.yaml

```yaml
description: "Example fine-tuning job"
method: supervised
train_config:
  "hf_api_key": "your_hugging_face_api_key"  # Replace with your actual API key
  "model": "NousResearch/Meta-Llama-3.1-8B-Instruct"  # Name of model repository on Hugging Face
  "max_epochs": 3  # Maximum epochs for training
  "lr": 2e-5  # Initial learning rate
  "training_file": "mqa_train_data_v2_with_persona_B_instruct"  # Name of training dataset
  "validation_file": "mqa_train_data_v2_with_persona_B_instruct"  # Name of validation dataset
  "chat": false  # Whether the training dataset is a multi-turn conversation dataset
  "global_batch_size": 128  # Batch size
  "max_seq_length": 1024  # Maximum number of tokens for input
```
Returns:
The method returns a FineTuningRequest object containing the following attributes:
- job_id (str): The unique identifier for the fine-tuning job.
- job_description (str): Description of the fine-tuning job.
- user_id (str): Unique identifier for the user (e.g., "test_user").
- method (str): The fine-tuning method specified in the user configuration.
- created_at (str): A formatted timestamp indicating when the job was created.
- error (str): Error message (empty string if no errors occurred).
- fine_tuned_model (str | None): Path to the trained model (None while the job is in progress).
- finished_at (str | None): Timestamp when the job completed (None while the job is in progress).
- train_config (dict[str, Any]): The complete fine-tuning configuration parameters.
- model (str): The base model being trained.
- seed (int): Random seed used for reproducibility.
- status (str): Current job status (e.g., "queued", "running", "completed", "failed").
- training_file (str): The training dataset used.
- validation_file (str): The validation dataset used.
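When inspecting these return values in client code, it can help to map the documented status strings to a done-or-not decision. This is a sketch: the vocabulary comes from the status values listed above, plus "cancelled" (assumed from the cancellation endpoint described below), and the helper itself is not part of the SDK:

```python
# Statuses after which a job can no longer make progress.
# "completed" and "failed" are documented; "cancelled" is assumed
# from the cancellation endpoint's behavior.
TERMINAL_STATUSES = {"completed", "failed", "cancelled"}


def is_terminal(status: str) -> bool:
    """Return True once a fine-tuning job has reached a final state."""
    return status.lower() in TERMINAL_STATUSES


print(is_terminal("running"))    # False
print(is_terminal("completed"))  # True
```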
Fine-Tuning Job Cancellation
The AsyncAIRefinery client allows you to cancel a running fine-tuning job by sending a POST request to the endpoint.
AsyncAIRefinery.fine_tuning.jobs.cancel()
Parameters:
- uuid (str): The unique identifier assigned to the user.
- fine_tuning_job_id (str): Unique identifier of the fine-tuning job to cancel.
- timeout (float | None): Maximum time (in seconds) to wait for a response. Defaults to 60 seconds if not provided.
- extra_headers (dict[str, str] | None): Request-specific headers that override any default headers.
- **kwargs: Additional parameters.
Returns:
The method returns a FineTuningRequest object with the updated status indicating the job has been cancelled.
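Since cancellation only makes sense while a job is still pending or in flight, a small guard can avoid redundant cancel calls. This is a sketch using the status names documented for create(); it is not part of the SDK:

```python
# Statuses in which a cancel request is still meaningful, based on the
# status values documented for create().
CANCELLABLE_STATUSES = {"queued", "running"}


def should_cancel(status: str) -> bool:
    """Only attempt cancellation while the job is queued or running."""
    return status.lower() in CANCELLABLE_STATUSES


print(should_cancel("queued"))     # True
print(should_cancel("completed"))  # False
```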
Listing of Fine-Tuning Events
The AsyncAIRefinery client allows you to retrieve all relevant events associated with a specific fine-tuning job.
AsyncAIRefinery.fine_tuning.jobs.list_events()
Parameters:
- fine_tuning_job_id (str): Unique identifier of the target fine-tuning job.
- uuid (str): The unique identifier assigned to the user.
- timeout (float | None): Maximum time (in seconds) to wait for a response. Defaults to 60 seconds if not provided.
- extra_headers (dict[str, str] | None): Request-specific headers that override any default headers.
- **kwargs: Additional parameters.
Returns:
The method returns a list of job status events in the events field (list[dict]), where each event contains:
- job_id (str): Unique job identifier.
- job_description (str): Job description for the user's information.
- user_id (str): Unique user identifier.
- created_at (str): Timestamp when the event was created.
- message (str): Description of the event that occurred (e.g., "job created", "job started", "cancelled").
- finished_at (str): Timestamp when the event was completed.
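Because each event is a plain dict with the fields listed above, a small helper can turn the events field into a readable timeline. This is a sketch over the documented fields; the sample data below is invented for illustration:

```python
def summarize_events(events: list[dict]) -> list[str]:
    """Render each event as 'created_at  message', oldest first."""
    ordered = sorted(events, key=lambda e: e.get("created_at", ""))
    return [f'{e.get("created_at", "?")}  {e.get("message", "")}' for e in ordered]


# Sample events shaped like the documented fields (values are illustrative).
sample = [
    {"job_id": "ft-123", "created_at": "2024-01-01 10:05:00", "message": "job started"},
    {"job_id": "ft-123", "created_at": "2024-01-01 10:00:00", "message": "job created"},
]
for line in summarize_events(sample):
    print(line)
```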
Example Usage
The following example demonstrates how to use the Fine-Tuning API to create, cancel, and monitor a fine-tuning job:
```python
import os
import asyncio

from omegaconf import OmegaConf
from air import AsyncAIRefinery

# Get the API key for the AI Refinery service from an environment variable
api_key = os.getenv("API_KEY")


async def async_fine_tuning_launch(client: AsyncAIRefinery):
    # Load the user config stored in the YAML file
    config = OmegaConf.load("fine_tuning.yaml")
    job_config = OmegaConf.to_container(config, resolve=True)

    # Use the fine-tuning sub-client to asynchronously submit a job to the computing cluster
    response = await client.fine_tuning.jobs.create(
        job_config=job_config,
        uuid="test_user",
    )
    # Print the response from the fine-tuning request
    print("Async fine-tuning launch response: ", response)
    return response


async def async_fine_tuning_cancel(client: AsyncAIRefinery, cancel_job_id: str, uuid: str):
    # Use the fine-tuning sub-client to cancel the job
    response = await client.fine_tuning.jobs.cancel(
        fine_tuning_job_id=cancel_job_id,
        uuid=uuid,
    )
    # Print the response from the cancel request
    print("Async fine-tuning cancel response: ", response)
    return response


async def async_fine_tuning_list_events(client: AsyncAIRefinery, event_job_id: str, uuid: str):
    # Use the fine-tuning sub-client to retrieve job events
    response = await client.fine_tuning.jobs.list_events(
        fine_tuning_job_id=event_job_id,
        uuid=uuid,
    )
    # Print the response from the list-events request
    print("Async list fine-tuning events response: ", response)
    return response


# Main execution block
if __name__ == "__main__":
    # Initialize the asynchronous client for the AI Refinery service with the API key
    client = AsyncAIRefinery(api_key=api_key)

    # Create and submit a fine-tuning job
    response = asyncio.run(async_fine_tuning_launch(client))

    # Cancel the fine-tuning job if needed
    asyncio.run(async_fine_tuning_cancel(client, cancel_job_id=response.job_id, uuid="test_user"))

    # List all events related to the job
    asyncio.run(async_fine_tuning_list_events(client, event_job_id=response.job_id, uuid="test_user"))
```
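In practice you often want to block until a job finishes. The sketch below polls a caller-supplied async status fetcher until a terminal status appears; in real use that fetcher could re-query the job or wrap client.fine_tuning.jobs.list_events(), but the fetcher here (and the "timeout" sentinel) is an assumption for illustration, not an SDK API:

```python
import asyncio

# Terminal statuses, based on the documented status values.
TERMINAL = {"completed", "failed", "cancelled"}


async def wait_until_done(fetch_status, poll_interval: float = 0.01, max_polls: int = 50) -> str:
    """Poll the async fetch_status callable until a terminal status appears."""
    for _ in range(max_polls):
        status = await fetch_status()
        if status in TERMINAL:
            return status
        await asyncio.sleep(poll_interval)
    return "timeout"


# Demo with a fake fetcher that reports "completed" on the third poll.
_statuses = iter(["queued", "running", "completed"])


async def fake_fetch():
    return next(_statuses)


final = asyncio.run(wait_until_done(fake_fetch))
print(final)  # completed
```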
Synchronous Fine-Tuning API
Synchronous Fine-Tuning Job Creation, Cancellation, and Event Listing
AIRefinery.fine_tuning.jobs.create(), AIRefinery.fine_tuning.jobs.cancel(), and AIRefinery.fine_tuning.jobs.list_events()
The AIRefinery client creates, cancels, and queries fine-tuning jobs synchronously. These methods support the same parameters and return structure as the asynchronous methods described above.
Example Usage

```python
import os

from omegaconf import OmegaConf
from air import AIRefinery

# Get the API key for the AI Refinery service from an environment variable
api_key = os.getenv("API_KEY")


def sync_fine_tuning_launch(client: AIRefinery):
    # Load the user config stored in the YAML file
    config = OmegaConf.load("fine_tuning.yaml")
    job_config = OmegaConf.to_container(config, resolve=True)

    # Use the fine-tuning sub-client to synchronously submit a job to the computing cluster
    response = client.fine_tuning.jobs.create(
        job_config=job_config,
        uuid="test_user",
    )
    # Print the response from the fine-tuning request
    print("Sync fine-tuning launch response: ", response)
    return response


def sync_fine_tuning_cancel(client: AIRefinery, cancel_job_id: str, uuid: str):
    # Use the fine-tuning sub-client to cancel the job
    response = client.fine_tuning.jobs.cancel(
        fine_tuning_job_id=cancel_job_id,
        uuid=uuid,
    )
    # Print the response from the cancel request
    print("Sync fine-tuning cancel response: ", response)
    return response


def sync_fine_tuning_list_events(client: AIRefinery, event_job_id: str, uuid: str):
    # Use the fine-tuning sub-client to retrieve job events
    response = client.fine_tuning.jobs.list_events(
        fine_tuning_job_id=event_job_id,
        uuid=uuid,
    )
    # Print the response from the list-events request
    print("Sync list fine-tuning events response: ", response)
    return response


# Main execution block
if __name__ == "__main__":
    # Initialize the synchronous client for the AI Refinery service with the API key
    client = AIRefinery(api_key=api_key)

    # Create and submit a fine-tuning job
    response = sync_fine_tuning_launch(client=client)

    # Cancel the fine-tuning job if needed
    sync_fine_tuning_cancel(client=client, cancel_job_id=response.job_id, uuid="test_user")

    # List all events related to the job
    sync_fine_tuning_list_events(client=client, event_job_id=response.job_id, uuid="test_user")
```