VertexAIChatCompletionBatch

langbatch.vertexai.VertexAIChatCompletionBatch

Bases: VertexAIBatch, ChatCompletionBatch

VertexAIChatCompletionBatch is a class for Vertex AI chat completion batches. It can be used for batch processing with Gemini models such as Gemini 1.5 Flash and Gemini 1.5 Pro.

Usage:

batch = VertexAIChatCompletionBatch(
    "path/to/file.jsonl",
    model="gemini-1.5-flash-001",
    gcp_project="your-gcp-project",
    bigquery_input_dataset="your-bigquery-input-dataset",
    bigquery_output_dataset="your-bigquery-output-dataset"
)
batch.start()

Source code in langbatch\vertexai.py
class VertexAIChatCompletionBatch(VertexAIBatch, ChatCompletionBatch):
    """
    VertexAIChatCompletionBatch is a class for Vertex AI chat completion batches.
    Can be used for batch processing with Gemini 1.5 Flash and Gemini 1.5 Pro models.

    Usage:
    ```python
    batch = VertexAIChatCompletionBatch(
        "path/to/file.jsonl",
        model="gemini-1.5-flash-001",
        gcp_project="your-gcp-project",
        bigquery_input_dataset="your-bigquery-input-dataset",
        bigquery_output_dataset="your-bigquery-output-dataset"
    )
    batch.start()
    ```
    """
    _url: str = "/v1/chat/completions"

    def _convert_request(self, req: dict) -> str:
        custom_schema = {
            "contents": [],
            "systemInstruction": None,
            "tools": [],
            "generationConfig": {}
        }
        request = json.loads(VertexAIChatCompletionRequest(**req["body"]).model_dump_json())

        # Track tool responses to match with tool calls
        tool_responses = {}

        # First pass - collect tool responses
        for message in request["messages"]:
            if message["role"] == "tool":
                tool_call_id = message["tool_call_id"]
                if tool_call_id:
                    tool_responses[tool_call_id] = {
                        "response": json.loads(message["content"])
                    }

        # Second pass - process messages
        for message in request["messages"]:
            role = message["role"]
            content = message["content"]

            function_calls = []
            tool_responses_cache = []
            if message.get("tool_calls"):
                for tool_call in message["tool_calls"]:
                    function_calls.append({
                        "functionCall": {
                            "name": tool_call["function"]["name"],
                            "args": json.loads(tool_call["function"]["arguments"])
                        }
                    })
                    # If we have a response for this tool call, add it in the next message
                    if tool_call["id"] in tool_responses:
                        response = tool_responses[tool_call["id"]]
                        response["name"] = tool_call["function"]["name"]
                        tool_responses_cache.append({
                            "functionResponse": response
                        })

            if role == "system":
                custom_schema["systemInstruction"] = {
                    "role": "system",
                    "parts": {"text": content}
                }
            elif role != "tool":  # Skip tool messages as we handle them separately
                if len(function_calls) > 0:
                    custom_schema["contents"].append({
                        "role": role,
                        "parts": function_calls
                    })
                    custom_schema["contents"].append({
                        "role": "model",
                        "parts": tool_responses_cache
                    })
                elif content is not None:
                    custom_schema["contents"].append({
                        "role": role,
                        "parts": {"text": content}
                    })

        # Convert tools
        if request["tools"]:
            for tool in request["tools"]:
                function = tool.get("function", {})

                custom_schema["tools"].append({
                    "functionDeclarations": [{
                        "name": function.get("name"),
                        "description": function.get("description", ""),
                        "parameters": function.get("parameters", {})
                    }]
                })

        # Convert generation config
        gen_config = custom_schema["generationConfig"]
        if request.get("temperature"):
            gen_config["temperature"] = request["temperature"]
        if request.get("top_p"):
            gen_config["topP"] = request["top_p"]
        if request.get("max_tokens"):
            gen_config["maxOutputTokens"] = request["max_tokens"]
        if request.get("n"):
            gen_config["candidateCount"] = request["n"]
        if request.get("presence_penalty"):
            gen_config["presencePenalty"] = request["presence_penalty"]
        if request.get("frequency_penalty"):
            gen_config["frequencyPenalty"] = request["frequency_penalty"]
        if request.get("stop"):
            gen_config["stopSequences"] = request["stop"] if isinstance(request["stop"], list) else [request["stop"]] if request["stop"] else None
        if request.get("seed"):
            gen_config["seed"] = request["seed"]

        if request.get("response_format"):
            mime_type_map = {
                "json_object": "application/json",
                "text": "text/plain",
                "json_schema": "application/json"
            }

            gen_config["responseMimeType"] = mime_type_map[request["response_format"]["type"]]

            if request["response_format"]["type"] == "json_schema" and request["response_format"]["json_schema"]:
                gen_config["responseSchema"] = request["response_format"]["json_schema"]

                # Check for single enum property to use text/x.enum mime type
                data = json.loads(request["response_format"]["json_schema"]["schema"])
                concrete_types = ["string", "number", "integer", "boolean"]
                if data.get("type") in concrete_types and len(data.get("enum", [])) > 0:
                    gen_config["responseMimeType"] = "text/x.enum"


        gen_config = {k: v for k, v in gen_config.items() if v is not None}
        custom_schema["generationConfig"] = gen_config

        return { 
            "custom_id": req["custom_id"],
            "request": json.dumps(custom_schema, indent=2)
        }

    def _convert_response(self, response):
        # Parse the input JSON
        response_data = json.loads(response["response"])

        status = response["status"]
        if status != "":
            if "Bad Request: " in status:
                error_data = json.loads(status.split("Bad Request: ")[1])
            else:
                error_data = {
                    "error": {
                        "message": status,
                        "code": "server_error"
                    }
                }

            error = {
                "message": error_data["error"]["message"],
                "code": error_data["error"]["code"]
            }

            res = None
        else:
            # Extract relevant information
            candidates = response_data["candidates"]
            tokens = response_data["usageMetadata"]

            # Create the choices array
            choices = []
            for index, candidate in enumerate(candidates):
                choice = {
                    "index": index,
                    "logprobs": None,
                    "finish_reason": candidate["finishReason"].lower()
                }

                tool_calls = []
                text_part = None
                for part in candidate["content"]["parts"]:
                    if part.get("functionCall", None):
                        tool_call = {
                            "type": "function",
                            "function": {
                                "name": part["functionCall"].get("name"),
                                "arguments": json.dumps(part["functionCall"].get("args", {}))
                            }
                        }
                        tool_calls.append(tool_call)
                    else:
                        text_part = part["text"]

                message = {
                    "role": "assistant",
                    "content": text_part
                }
                if len(tool_calls) > 0:
                    message["tool_calls"] = tool_calls

                choice["message"] = message
                choices.append(choice)

            usage = {
                "prompt_tokens": tokens.get("promptTokenCount", 0),
                "completion_tokens": tokens.get("candidatesTokenCount", 0),
                "total_tokens": tokens.get("totalTokenCount", 0)
            }

            # Create the body
            body = {
                "id": f'{response["custom_id"]}',
                "object": "chat.completion",
                "created": int(response["processed_time"].timestamp()),
                "model": self.model,
                "system_fingerprint": None,
                "choices": choices,
                "usage": usage
            }

            res = {
                "request_id": response["custom_id"],
                "status_code": 200,
                "body": body,
            }

            error = None

        # create output
        output = {
            "id": f'{response["custom_id"]}',
            "custom_id": response["custom_id"],
            "response": res,
            "error": error
        }

        return output

    def _validate_request(self, request):
        VertexAIChatCompletionRequest(**request)
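
To illustrate the conversion, here is a minimal sketch of what the internal _convert_request helper produces for a simple OpenAI-style request (assumptions: the batch was constructed as in the usage above with a valid input file, and the model id is a placeholder):

import json

converted = batch._convert_request({
    "custom_id": "request-1",
    "body": {
        "model": "gemini-1.5-flash-001",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "What is the capital of France?"}
        ],
        "temperature": 0.2,
        "max_tokens": 256
    }
})

print(converted["custom_id"])  # "request-1"
vertex_request = json.loads(converted["request"])
# vertex_request is the Vertex AI schema: "contents" (the user turn),
# "systemInstruction", "tools", and "generationConfig" with "temperature"
# and "maxOutputTokens" mapped from the OpenAI-style fields.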

platform_batch_id class-attribute instance-attribute

platform_batch_id: str | None = None

id instance-attribute

id = str(uuid4())

model instance-attribute

model = model

gcp_project instance-attribute

gcp_project = gcp_project

bigquery_input_dataset instance-attribute

bigquery_input_dataset = bigquery_input_dataset

bigquery_output_dataset instance-attribute

bigquery_output_dataset = bigquery_output_dataset

__init__

__init__(file: str, model: str, gcp_project: str, bigquery_input_dataset: str, bigquery_output_dataset: str) -> None

Initialize the VertexAIBatch class.

Parameters:

  • file (str) –

    The path to the jsonl file in Vertex AI batch format.

  • model (str) –

    The name of the model to use for the batch prediction.

  • gcp_project (str) –

    The GCP project to use for the batch prediction.

  • bigquery_input_dataset (str) –

    The BigQuery dataset to use for the batch prediction input.

  • bigquery_output_dataset (str) –

    The BigQuery dataset to use for the batch prediction output.

Usage:

batch = VertexAIBatch(
    "path/to/file.jsonl",
    "model",
    "gcp_project",
    "bigquery_input_dataset",
    "bigquery_output_dataset"
)

Source code in langbatch\vertexai.py
def __init__(self, file: str, model: str, gcp_project: str, bigquery_input_dataset: str, bigquery_output_dataset: str) -> None:
    """
    Initialize the VertexAIBatch class.

    Args:
        file (str): The path to the jsonl file in Vertex AI batch format.
        model (str): The name of the model to use for the batch prediction.
        gcp_project (str): The GCP project to use for the batch prediction.
        bigquery_input_dataset (str): The BigQuery dataset to use for the batch prediction input.
        bigquery_output_dataset (str): The BigQuery dataset to use for the batch prediction output.

    Usage:
    ```python
    batch = VertexAIBatch(
        "path/to/file.jsonl",
        "model",
        "gcp_project",
        "bigquery_input_dataset",
        "bigquery_output_dataset"
    )
    ```
    """
    super().__init__(file)

    self.model = model
    self.gcp_project = gcp_project
    self.bigquery_input_dataset = bigquery_input_dataset
    self.bigquery_output_dataset = bigquery_output_dataset

create_from_requests classmethod

create_from_requests(requests, batch_kwargs: Dict = {})

Creates a batch when given a list of requests. These requests must be in the correct Batch API request format for the batch type. For example, for OpenAIChatCompletionBatch, each request should be a Chat Completion request with a custom_id.

Parameters:

  • requests –

    A list of requests.

  • batch_kwargs (Dict, default: {} ) –

    Additional keyword arguments for the batch class. Ex. gcp_project, etc. for VertexAIChatCompletionBatch.

Returns:

  • –

    An instance of the Batch class.

Raises:

  • BatchInitializationError –

    If the input data is invalid.

Usage:

batch = OpenAIChatCompletionBatch.create_from_requests([
    {   "custom_id": "request-1",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Biryani Receipe, pls."}],
            "max_tokens": 1000
        }
    },
    {
        "custom_id": "request-2",
        "method": "POST",
        "url": "/v1/chat/completions",
        "body": {
            "model": "gpt-4o-mini",
            "messages": [{"role": "user", "content": "Write a short story about AI"}],
            "max_tokens": 1000
        }
    }
])

Source code in langbatch\Batch.py
@classmethod
def create_from_requests(cls, requests, batch_kwargs: Dict = {}):
    """
    Creates a batch when given a list of requests. 
    These requests must be in the correct Batch API request format for the batch type.
    Ex. for OpenAIChatCompletionBatch, each request should be a Chat Completion request with a custom_id.

    Args:
        requests: A list of requests.
        batch_kwargs (Dict, optional): Additional keyword arguments for the batch class. Ex. gcp_project, etc. for VertexAIChatCompletionBatch.

    Returns:
        An instance of the Batch class.

    Raises:
        BatchInitializationError: If the input data is invalid.

    Usage:
    ```python
    batch = OpenAIChatCompletionBatch.create_from_requests([
        {   "custom_id": "request-1",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o-mini",
                "messages": [{"role": "user", "content": "Biryani Receipe, pls."}],
                "max_tokens": 1000
            }
        },
        {
            "custom_id": "request-2",
            "method": "POST",
            "url": "/v1/chat/completions",
            "body": {
                "model": "gpt-4o-mini",
                "messages": [{"role": "user", "content": "Write a short story about AI"}],
                "max_tokens": 1000
            }
        }
    ])
    ``` 
    """

    file_path = cls._create_batch_file_from_requests(requests)

    if file_path is None:
        raise BatchInitializationError("Failed to create batch. Check the input data.")

    return cls(file_path, **batch_kwargs)

load classmethod

load(id: str, storage: BatchStorage = FileBatchStorage(), batch_kwargs: Dict = {})

Load a batch from the storage and return a Batch object.

Parameters:

  • id (str) –

    The id of the batch.

  • storage (BatchStorage, default: FileBatchStorage() ) –

    The storage to load the batch from. Defaults to FileBatchStorage().

  • batch_kwargs (Dict, default: {} ) –

    Additional keyword arguments for the batch class. Ex. gcp_project, etc. for VertexAIChatCompletionBatch.

Returns:

  • Batch –

    The batch object.

Usage:

batch = OpenAIChatCompletionBatch.load("123", storage=FileBatchStorage("./data"))

Source code in langbatch\Batch.py
@classmethod
def load(cls, id: str, storage: BatchStorage = FileBatchStorage(), batch_kwargs: Dict = {}):
    """
    Load a batch from the storage and return a Batch object.

    Args:
        id (str): The id of the batch.
        storage (BatchStorage, optional): The storage to load the batch from. Defaults to FileBatchStorage().
        batch_kwargs (Dict, optional): Additional keyword arguments for the batch class. Ex. gcp_project, etc. for VertexAIChatCompletionBatch.

    Returns:
        Batch: The batch object.

    Usage:
    ```python
    batch = OpenAIChatCompletionBatch.load("123", storage=FileBatchStorage("./data"))
    ```
    """
    data_file, meta_file = storage.load(id)

    # Load metadata based on file extension
    if meta_file.suffix == '.json':
        with open(meta_file, 'r') as f:
            meta_data = json.load(f)
    else:  # .pkl
        with open(meta_file, 'rb') as f:
            meta_data = pickle.load(f)

    init_args = cls._get_init_args(meta_data)

    for key, value in batch_kwargs.items():
        if key not in init_args:
            init_args[key] = value

    batch = cls(str(data_file), **init_args)
    batch.platform_batch_id = meta_data['platform_batch_id']
    batch.id = id

    return batch

save

save(storage: BatchStorage = FileBatchStorage())

Save the batch to the storage.

Parameters:

  • storage (BatchStorage, default: FileBatchStorage() ) –

    The storage to save the batch to. Defaults to FileBatchStorage().

Usage:

batch = OpenAIChatCompletionBatch(file)
batch.save()

# save the batch to file storage
batch.save(storage=FileBatchStorage("./data"))

Source code in langbatch\Batch.py
def save(self, storage: BatchStorage = FileBatchStorage()):
    """
    Save the batch to the storage.

    Args:
        storage (BatchStorage, optional): The storage to save the batch to. Defaults to FileBatchStorage().

    Usage:
    ```python
    batch = OpenAIChatCompletionBatch(file)
    batch.save()

    # save the batch to file storage
    batch.save(storage=FileBatchStorage("./data"))
    ```
    """
    meta_data = self._create_meta_data()
    meta_data["platform_batch_id"] = self.platform_batch_id

    storage.save(self.id, Path(self._file), meta_data)

start

start()
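
Starts the batch prediction job: uploads the batch data to the configured BigQuery input dataset, creates an output table in the output dataset, and submits the job to Vertex AI. Raises BatchStateError if the batch has already been started.

Usage:

batch.start()
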
Source code in langbatch\vertexai.py
def start(self):
    if self.platform_batch_id is not None:
        raise BatchStateError("Batch already started")

    input_dataset = self._upload_batch_file()
    output_dataset_id = self._create_table(self.bigquery_output_dataset)
    output_dataset = f"bq://{self.gcp_project}.{self.bigquery_output_dataset}.{output_dataset_id}"
    self._create_batch(input_dataset, output_dataset)

get_status

get_status()
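
Returns the current status of the batch prediction job, mapped from the Vertex AI job state to a common status string via vertexai_state_map. Raises BatchStateError if the batch has not been started.

Usage:

import time

# a minimal polling sketch; a real loop should also handle failure states
while batch.get_status() != "completed":
    time.sleep(60)
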
Source code in langbatch\vertexai.py
def get_status(self):
    if self.platform_batch_id is None:
        raise BatchStateError("Batch not started")

    job = BatchPredictionJob(self.platform_batch_id)
    return vertexai_state_map[str(job.state.name)]

get_results_file

get_results_file()

Usage:

import jsonlines

# create a batch and start batch process
batch = OpenAIChatCompletionBatch(file)
batch.start()

if batch.get_status() == "completed":
    # get the results file
    results_file = batch.get_results_file()

    with jsonlines.open(results_file) as reader:
        for obj in reader:
            print(obj)

Source code in langbatch\Batch.py
def get_results_file(self):
    """
    Usage:
    ```python
    import jsonlines

    # create a batch and start batch process
    batch = OpenAIChatCompletionBatch(file)
    batch.start()

    if batch.get_status() == "completed":
        # get the results file
        results_file = batch.get_results_file()

        with jsonlines.open(results_file) as reader:
            for obj in reader:
                print(obj)
    ```
    """
    file_path = self._download_results_file()
    return file_path

get_results

get_results() -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]] | Tuple[None, None]

Retrieve the results of the chat completion batch.

Returns:

  • Tuple[List[Dict[str, Any]], List[Dict[str, Any]]] | Tuple[None, None] –

    A tuple containing successful and unsuccessful results. Successful results: A list of dictionaries with "choices" and "custom_id" keys. Unsuccessful results: A list of dictionaries with "error" and "custom_id" keys.

Usage:

successful_results, unsuccessful_results = batch.get_results()
for result in successful_results:
    print(result["choices"])
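
Unsuccessful results can be inspected in the same way; each entry carries the error along with the originating custom_id:

for result in unsuccessful_results:
    print(result["custom_id"], result["error"])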

Source code in langbatch\ChatCompletionBatch.py
def get_results(self) -> Tuple[List[Dict[str, Any]], List[Dict[str, Any]]] | Tuple[None, None]:
    """
    Retrieve the results of the chat completion batch.

    Returns:
        A tuple containing successful and unsuccessful results. Successful results: A list of dictionaries with "choices" and "custom_id" keys. Unsuccessful results: A list of dictionaries with "error" and "custom_id" keys.

    Usage:
    ```python
    successful_results, unsuccessful_results = batch.get_results()
    for result in successful_results:
        print(result["choices"])
    ```
    """
    process_func = lambda result: {"choices": result['response']['body']['choices']}
    return self._prepare_results(process_func)

is_retryable_failure

is_retryable_failure() -> bool
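
Checks the recorded batch errors to decide whether a failed batch is worth retrying. Retry logic for the Vertex AI API is not yet implemented (see the TODO in the source), so this currently always returns False.

Usage:

if batch.is_retryable_failure():
    batch.retry()
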
Source code in langbatch\vertexai.py
def is_retryable_failure(self) -> bool:
    # TODO: implement retry logic for Vertex AI API
    error = self._get_errors()
    if error:
        logging.error(f"Error in VertexAI Batch: {error}")
        if "Failed to import data. Not found: Dataset" in error:
            return False
        else:
            return False
    else:
        return False

retry

retry()
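
Re-submits the batch prediction job, reusing the BigQuery input and output datasets from the previous job. Raises BatchStateError if the batch has not been started.
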
Source code in langbatch\vertexai.py
def retry(self):
    if self.platform_batch_id is None:
        raise BatchStateError("Batch not started")

    job = BatchPredictionJob(self.platform_batch_id)
    input_dataset = job._gca_resource.input_config.bigquery_source.input_uri
    output_dataset = job._gca_resource.output_config.bigquery_destination.output_uri

    self._create_batch(input_dataset, output_dataset)

get_unsuccessful_requests

get_unsuccessful_requests() -> List[Dict[str, Any]]

Retrieve the unsuccessful requests from the batch.

Returns:

  • List[Dict[str, Any]] –

    A list of requests that failed.

Usage:

batch = OpenAIChatCompletionBatch(file)
batch.start()

if batch.get_status() == "completed":
    # get the unsuccessful requests
    unsuccessful_requests = batch.get_unsuccessful_requests()

    for request in unsuccessful_requests:
        print(request["custom_id"])

Source code in langbatch\Batch.py
def get_unsuccessful_requests(self) -> List[Dict[str, Any]]:
    """
    Retrieve the unsuccessful requests from the batch.

    Returns:
        A list of requests that failed.

    Usage:
    ```python
    batch = OpenAIChatCompletionBatch(file)
    batch.start()

    if batch.get_status() == "completed":
        # get the unsuccessful requests
        unsuccessful_requests = batch.get_unsuccessful_requests()

        for request in unsuccessful_requests:
            print(request["custom_id"])
    ```
    """
    custom_ids = []
    _, unsuccessful_results = self.get_results()
    for result in unsuccessful_results:
        custom_ids.append(result["custom_id"])

    return self.get_requests_by_custom_ids(custom_ids)

get_requests_by_custom_ids

get_requests_by_custom_ids(custom_ids: List[str]) -> List[Dict[str, Any]]

Retrieve the requests from the batch file by custom ids.

Parameters:

  • custom_ids (List[str]) –

    A list of custom ids.

Returns:

  • List[Dict[str, Any]] –

    A list of requests.

Usage:

batch = OpenAIChatCompletionBatch(file)
batch.start()

if batch.get_status() == "completed":
    # get the requests by custom ids
    requests = batch.get_requests_by_custom_ids(["custom_id1", "custom_id2"])

    for request in requests:
        print(request["custom_id"])

Source code in langbatch\Batch.py
def get_requests_by_custom_ids(self, custom_ids: List[str]) -> List[Dict[str, Any]]:
    """
    Retrieve the requests from the batch file by custom ids.

    Args:
        custom_ids (List[str]): A list of custom ids.

    Returns:
        A list of requests.

    Usage:
    ```python
    batch = OpenAIChatCompletionBatch(file)
    batch.start()

    if batch.get_status() == "completed":
        # get the requests by custom ids
        requests = batch.get_requests_by_custom_ids(["custom_id1", "custom_id2"])

        for request in requests:
            print(request["custom_id"])
    ```
    """
    requests = []
    with jsonlines.open(self._file) as reader:
        for request in reader:
            if request["custom_id"] in custom_ids:
                requests.append(request)
    return requests

create classmethod

create(data: List[Iterable[ChatCompletionMessageParam]], request_kwargs: Dict = {}, batch_kwargs: Dict = {}) -> ChatCompletionBatch

Create a chat completion batch when given a list of messages.

Parameters:

  • data (List[Iterable[ChatCompletionMessageParam]]) –

    A list of messages to be sent to the API.

  • request_kwargs (Dict, default: {} ) –

    Additional keyword arguments for the API call. Ex. model, messages, etc.

  • batch_kwargs (Dict, default: {} ) –

    Additional keyword arguments for the batch class. Ex. gcp_project, etc. for VertexAIChatCompletionBatch.

Returns:

  • ChatCompletionBatch –

    An instance of the ChatCompletionBatch class.

Raises:

  • BatchInitializationError –

    If the input data is invalid.

Usage:

batch = OpenAIChatCompletionBatch.create([
        [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"}],
        [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of Germany?"}]
    ],
    request_kwargs={"model": "gpt-4o"})

# For Vertex AI
batch = VertexAIChatCompletionBatch.create([
        [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"}],
        [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of Germany?"}]
    ],
    request_kwargs={"model": "gemini-2.0-flash-001"},
    batch_kwargs={
        "gcp_project": "your-gcp-project", 
        "bigquery_input_dataset": "your-bigquery-input-dataset", 
        "bigquery_output_dataset": "your-bigquery-output-dataset"
    })

Source code in langbatch\ChatCompletionBatch.py
@classmethod
def create(cls, data: List[Iterable[ChatCompletionMessageParam]], request_kwargs: Dict = {}, batch_kwargs: Dict = {}) -> "ChatCompletionBatch":
    """
    Create a chat completion batch when given a list of messages.

    Args:
        data (List[Iterable[ChatCompletionMessageParam]]): A list of messages to be sent to the API.
        request_kwargs (Dict): Additional keyword arguments for the API call. Ex. model, messages, etc.
        batch_kwargs (Dict): Additional keyword arguments for the batch class. Ex. gcp_project, etc. for VertexAIChatCompletionBatch.

    Returns:
        An instance of the ChatCompletionBatch class.

    Raises:
        BatchInitializationError: If the input data is invalid.

    Usage:
    ```python
    batch = OpenAIChatCompletionBatch.create([
            [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"}],
            [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of Germany?"}]
        ],
        request_kwargs={"model": "gpt-4o"})

    # For Vertex AI
    batch = VertexAIChatCompletionBatch.create([
            [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of France?"}],
            [{"role": "system", "content": "You are a helpful assistant."}, {"role": "user", "content": "What is the capital of Germany?"}]
        ],
        request_kwargs={"model": "gemini-2.0-flash-001"},
        batch_kwargs={
            "gcp_project": "your-gcp-project", 
            "bigquery_input_dataset": "your-bigquery-input-dataset", 
            "bigquery_output_dataset": "your-bigquery-output-dataset"
        })
    ```
    """
    return cls._create_batch_file("messages", data, request_kwargs, batch_kwargs)