Model Fine-Tuning API

This documentation provides a comprehensive guide to the Model Fine-Tuning API within AI Refinery. The API enables you to customize AI models with your own data to meet your specific requirements. You can access this functionality through our SDK using either the AIRefinery or AsyncAIRefinery clients.

The Fine-Tuning API allows you to:

  • Create custom models tailored to your specific use cases
  • Improve model performance on domain-specific tasks
  • Adapt pre-trained models from Hugging Face to your particular data distribution

Asynchronous Fine-tuning API

Fine-Tuning Job Creation

The AsyncAIRefinery client allows you to submit a fine-tuning job to the computing cluster asynchronously by sending a POST request to the endpoint.

AsyncAIRefinery.fine_tuning.jobs.create()

Parameters:
  • job_config (dict[str, Any] | FineTuningJobConfig): The job description and hyper-parameters for the fine-tuning process. See the template configuration in YAML format below.
  • uuid (str): The unique identifier assigned to the user.
  • timeout (float | None): Maximum time (in seconds) to wait for a response. Defaults to 60 seconds if not provided.
  • extra_headers (dict[str, str] | None): Request-specific headers that override any default headers.
  • **kwargs: Additional parameters.
job_config Parameters:
  • description (str): A brief description of the fine-tuning job.
  • method (str): The fine-tuning method to use (currently, only "supervised" is supported).
  • train_config (dict[str, Any]):
    • hf_api_key (str): Your Hugging Face API key for accessing model repositories.
    • model (str): The base model in Hugging Face to fine-tune (e.g., "NousResearch/Meta-Llama-3.1-8B-Instruct").
    • max_epochs (int): Maximum number of fine-tuning epochs.
    • lr (float): Learning rate for the optimization process.
    • training_file (str): Name of the training dataset stored in the cloud.
    • validation_file (str): Name of the validation dataset stored in the cloud.
    • chat (bool): Whether the training data is a multi-turn chat dataset.
    • global_batch_size (int): Total batch size across all devices.
    • max_seq_length (int): Maximum sequence length for input tokens.
Template of job config fine_tuning.yaml
description: "Example fine-tuning job"
method: supervised

train_config:
  "hf_api_key": "your_hugging_face_api_key"  # Replace with your actual API key
  "model": "NousResearch/Meta-Llama-3.1-8B-Instruct"  # Name of model repository in Hugging Face
  "max_epochs": 3   # Maximum epochs for training
  "lr": 2e-5  # Initialized learning rate
  "training_file": "mqa_train_data_v2_with_persona_B_instruct"  # Name of training dataset
  "validation_file": "mqa_train_data_v2_with_persona_B_instruct"  # Name of validation dataset
  "chat": false  # Whether the training dataset is multi-turn conversation dataset
  "global_batch_size": 128  # Batch size 
  "max_seq_length": 1024  # Maximum number of tokens for input
Returns:

The method returns a FineTuningRequest object containing the following attributes:

  • job_id (str): The unique identifier for the fine-tuning job.

  • job_description (str): Description of the fine-tuning job.

  • user_id (str): Unique identifier for the user (e.g., 'test_user').

  • method (str): The fine-tuning method specified in the user configuration.

  • created_at (str): A formatted timestamp indicating when the job was created.

  • error (str): Error message (empty string if no errors occurred).

  • fine_tuned_model (str | None): Path to the trained model (None while job is in progress).

  • finished_at (str | None): Timestamp when the job completed (None while job is in progress).

  • train_config (dict[str, Any]): The complete fine-tuning configuration parameters.

  • model (str): The base model being trained.

  • seed (int): Random seed used for reproducibility.

  • status (str): Current job status (e.g., "queued", "running", "completed", "failed").

  • training_file (str): The training dataset used.

  • validation_file (str): The validation dataset used.
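
Since fine_tuned_model and finished_at remain None while the job is in progress, a common pattern is to poll the status field until it reaches a terminal value. A minimal helper sketch, assuming the status strings listed above plus "cancelled" after a cancel request:

```python
# Terminal statuses after which a fine-tuning job will not change state again.
# "completed" and "failed" come from the status values listed above;
# "cancelled" is assumed to be set after a successful cancel request.
TERMINAL_STATUSES = {"completed", "failed", "cancelled"}

def is_finished(status: str) -> bool:
    """Return True once a fine-tuning job has reached a terminal status."""
    return status in TERMINAL_STATUSES
```

In a polling loop you would re-fetch the job, call is_finished(response.status), and only read fine_tuned_model once it returns True.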

Fine-Tuning Job Cancellation

The AsyncAIRefinery client allows you to cancel a running fine-tuning job by sending a POST request to the endpoint.

AsyncAIRefinery.fine_tuning.jobs.cancel()

Parameters:
  • uuid (str): The unique identifier assigned to the user.
  • fine_tuning_job_id (str): Unique identifier of the fine-tuning job to cancel.
  • timeout (float | None): Maximum time (in seconds) to wait for a response. Defaults to 60 seconds if not provided.
  • extra_headers (dict[str, str] | None): Request-specific headers that override any default headers.
  • **kwargs: Additional parameters.
Returns:

The method returns a FineTuningRequest object with the updated status indicating the job has been cancelled.

Listing Fine-Tuning Events

The AsyncAIRefinery client allows you to retrieve all relevant events associated with a specific fine-tuning job.

AsyncAIRefinery.fine_tuning.jobs.list_events()

Parameters:
  • uuid (str): The unique identifier assigned to the user.
  • fine_tuning_job_id (str): Unique identifier of the target fine-tuning job.
  • timeout (float | None): Maximum time (in seconds) to wait for a response. Defaults to 60 seconds if not provided.
  • extra_headers (dict[str, str] | None): Request-specific headers that override any default headers.
  • **kwargs: Additional parameters.
Returns:

The method returns a list of job status events in the events field (list[dict]), where each event contains:

  • job_id (str): Unique job identifier.

  • job_description (str): Description of the fine-tuning job.

  • user_id (str): Unique user identifier.

  • created_at (str): Timestamp when the event was created.

  • message (str): Description of the event that occurred (e.g., "job created", "job started", "cancelled").

  • finished_at (str): Timestamp when the event was completed.
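
The events list gives a quick timeline of a job's lifecycle. A small sketch that orders events by created_at and renders them as readable lines; the field names follow the list above, and the sample timestamps and messages are illustrative:

```python
def format_event_timeline(events: list[dict]) -> list[str]:
    """Render job events as 'timestamp  message' lines, oldest first."""
    ordered = sorted(events, key=lambda e: e["created_at"])
    return [f'{e["created_at"]}  {e["message"]}' for e in ordered]

# Illustrative events, shaped like entries of the `events` field
sample = [
    {"created_at": "2024-01-01 10:05:00", "message": "job started"},
    {"created_at": "2024-01-01 10:00:00", "message": "job created"},
]
for line in format_event_timeline(sample):
    print(line)
```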

Example Usage

The following example demonstrates how to use the Fine-Tuning API to create, cancel, and monitor a fine-tuning job:

import os
import asyncio
from omegaconf import OmegaConf
from air import AsyncAIRefinery

# Get API_KEY for AI Refinery service from environment variable
api_key = os.getenv("API_KEY")

async def async_fine_tuning_launch(client: AsyncAIRefinery):
    # Load the user config stored in the yaml file
    config = OmegaConf.load("fine_tuning.yaml")
    job_config = OmegaConf.to_container(config, resolve=True)

    # Use the fine-tuning sub-client to asynchronously submit a job to the computing cluster
    response = await client.fine_tuning.jobs.create(
        job_config=job_config,
        uuid="test_user",
    )

    # Print the response from the fine-tuning request
    print("Async fine-tuning launch response: ", response)

    return response

async def async_fine_tuning_cancel(client: AsyncAIRefinery, cancel_job_id: str, uuid: str):
    # Use the fine-tuning sub-client to cancel the job
    response = await client.fine_tuning.jobs.cancel(
        fine_tuning_job_id=cancel_job_id,
        uuid=uuid,
    )

    # Print the response from the cancel request
    print("Async fine-tuning cancel response: ", response)

    return response

async def async_fine_tuning_list_events(client: AsyncAIRefinery, event_job_id: str, uuid: str):    
    # Use the fine-tuning sub-client to retrieve job events
    response = await client.fine_tuning.jobs.list_events(
        fine_tuning_job_id=event_job_id,
        uuid=uuid,
    )

    # Print the response from the list events request
    print("Async list fine-tuning events response: ", response)

    return response


# Main execution block
if __name__ == "__main__":
    # Initialize the asynchronous client for AI Refinery service with authenticated API-key
    client = AsyncAIRefinery(api_key=api_key)

    # Create and submit a fine-tuning job
    response = asyncio.run(
        async_fine_tuning_launch(client)
    )

    # Cancel the fine-tuning job if needed
    asyncio.run(async_fine_tuning_cancel(client, cancel_job_id=response.job_id, uuid="test_user"))

    # List all events related to the job
    asyncio.run(async_fine_tuning_list_events(client, event_job_id=response.job_id, uuid="test_user"))

Synchronous Fine-tuning API

Synchronous Fine-Tuning Job Creation, Cancellation, and Event Listing

AIRefinery.fine_tuning.jobs.create(), AIRefinery.fine_tuning.jobs.cancel() and AIRefinery.fine_tuning.jobs.list_events()

The AIRefinery client creates, cancels, and queries fine-tuning jobs synchronously. These methods support the same parameters and return structures as the asynchronous methods described above.

Example Usage
import os
from omegaconf import OmegaConf
from air import AIRefinery

# Get API_KEY for AI Refinery service from environment variable
api_key = os.getenv("API_KEY")

def sync_fine_tuning_launch(client: AIRefinery):
    # Load the user config stored in the yaml file
    config = OmegaConf.load("fine_tuning.yaml")
    job_config = OmegaConf.to_container(config, resolve=True)

    # Use the fine-tuning sub-client to synchronously submit a job to the computing cluster
    response = client.fine_tuning.jobs.create(
        job_config=job_config,
        uuid="test_user",
    )

    # Print the response from the fine-tuning request
    print("Sync fine-tuning launch response: ", response)

    return response

def sync_fine_tuning_cancel(client: AIRefinery, cancel_job_id: str, uuid: str):
    # Use the fine-tuning sub-client to cancel the job
    response = client.fine_tuning.jobs.cancel(
        fine_tuning_job_id=cancel_job_id,
        uuid=uuid,
    )

    # Print the response from the cancel request
    print("Sync fine-tuning cancel response: ", response)

    return response

def sync_fine_tuning_list_events(client: AIRefinery, event_job_id: str, uuid: str):    
    # Use the fine-tuning sub-client to retrieve job events
    response = client.fine_tuning.jobs.list_events(
        fine_tuning_job_id=event_job_id,
        uuid=uuid,
    )

    # Print the response from the list events request
    print("Sync list fine-tuning events response: ", response)

    return response


# Main execution block
if __name__ == "__main__":
    # Initialize the synchronous client for AI Refinery service with authenticated API-key
    client = AIRefinery(api_key=api_key)

    # Create and submit a fine-tuning job
    response = sync_fine_tuning_launch(client=client)

    # Cancel the fine-tuning job if needed
    sync_fine_tuning_cancel(client=client, cancel_job_id=response.job_id, uuid="test_user")

    # List all events related to the job
    sync_fine_tuning_list_events(client=client, event_job_id=response.job_id, uuid="test_user")