import json
import requests
from requests.adapters import HTTPAdapter
from urllib3.util import Retry
import time
import re
class ApiException(Exception):
    """
    Custom exception class for handling API-specific errors.

    :param msg: Error message describing the API failure
    :param error_code: HTTP status code from the API response
    """

    def __init__(self, msg, error_code):
        self.msg = msg
        self.error_code = error_code
class ApiProxy:
    """
    A proxy class for handling HTTP communications with LLM API endpoints.

    This class manages API request sessions, implements retry logic, and handles
    various HTTP status codes with backoff for rate limits and server errors.

    :param url: The API endpoint URL
    :param api_key: Optional API key for authentication with LLM services
    """

    def __init__(self, url, api_key=None):
        retry_strategy = Retry(
            total=1,
            backoff_factor=1,
            status_forcelist=[429, 500, 502, 503, 504],
            allowed_methods=["POST"],
        )
        # Attach the retry strategy to the adapter so it is actually applied.
        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session = requests.Session()
        self.session.mount("https://", adapter)
        self.session.mount("http://", adapter)
        self.api_key = api_key
        self.url = url
    def __call__(self, params_gpt, headers=None):
        # Avoid a shared mutable default argument; copy so the caller's dict is not modified.
        headers = dict(headers) if headers else {}
        headers.setdefault('Content-Type', 'application/json')
        if self.api_key:
            headers['Authorization'] = "Bearer " + self.api_key
        try:
            response = self.session.post(self.url, headers=headers, data=json.dumps(params_gpt))
        except Exception:
            # Transient network failure: wait briefly, then retry the request once.
            time.sleep(10)
            response = self.session.post(self.url, headers=headers, data=json.dumps(params_gpt))
        if response.status_code in (429, 404, 500, 502, 503, 504, 104):
            # Keep retrying rate-limit and server errors, backing off between attempts.
            for idx in range(10000000000000000):
                interval = min(10 + idx * 5, 60)  # grow the wait, capped at 60 seconds
                if idx % 10 == 0:
                    print("retryable error, code: {}, msg: {}, waiting {}s (attempt {})".format(
                        response.status_code, response.text, interval, idx
                    ), flush=True)
                time.sleep(interval)
                response = self.session.post(self.url, headers=headers, data=json.dumps(params_gpt))
                if response.status_code not in (429, 404, 500, 502, 503, 504, 104):
                    break
        if response.status_code != 200:
            err_msg = f"access error, status code: {response.status_code}, errmsg: {response.text}"
            raise ApiException(err_msg, response.status_code)
        data = json.loads(response.text)
        return data
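
# Usage sketch (not part of the original module; the endpoint URL, model name and
# payload below are hypothetical). ApiProxy posts any JSON-serialisable payload to
# the configured URL and raises ApiException for non-200 responses:
#
#     proxy = ApiProxy(url="https://api.example.com/v1/chat/completions", api_key="YOUR_API_KEY")
#     try:
#         reply = proxy({"model": "my-model",
#                        "messages": [{"role": "user", "content": "Hello"}]})
#         print(reply["choices"][0]["message"]["content"])
#     except ApiException as exc:
#         print(exc.error_code, exc.msg)
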
class LMAgent:
    """
    A high-level interface for interacting with large language models.

    :param config: Dictionary of model configuration parameters; expected keys are
        ``model_name``, ``api_key``, ``api_url``, ``temperature``, ``max_new_tokens``
        and ``top_p``.
    """

    def __init__(self, config):
        self.model = config["model_name"]
        self.apikey = config["api_key"]
        self.temperature = config["temperature"]
        self.max_new_tokens = config["max_new_tokens"]
        self.top_p = config["top_p"]
        self._proxy = ApiProxy(url=config['api_url'] + '/v1/chat/completions', api_key=self.apikey)
    def __call__(self, system_prompt, input_prompt, max_new_tokens):
        # The system and user messages are derived by splitting the prompt on the
        # "## Input Data" delimiter; the system_prompt and max_new_tokens arguments
        # are accepted for interface compatibility but not used here.
        if isinstance(input_prompt, list):
            system_message, input_message = input_prompt[0].split("## Input Data")
        else:
            system_message, input_message = input_prompt.split("## Input Data")
        messages = [
            {"role": "system", "content": system_message},
            {"role": "user", "content": "## Input Data" + input_message},
        ]
        prompt_dict = {
            "model": self.model,
            "messages": messages,
            "temperature": self.temperature,
            "max_tokens": self.max_new_tokens,
            "top_p": self.top_p,
            "frequency_penalty": 0.0,
            "presence_penalty": 0.0,
        }
        response = self._proxy(prompt_dict)
        content = response["choices"][0]["message"]["content"]
        # Extract the {"rerank_list": "..."} fragment from the model output, if present.
        pattern = r'\{"rerank_list": "[^"]+"\}'
        match = re.search(pattern, content, re.DOTALL)
        if match:
            # Return the matched JSON fragment (still a string, not a parsed dictionary).
            result = match.group()
            return [result] if isinstance(input_prompt, list) else result
        else:
            return [content] if isinstance(input_prompt, list) else content
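
# Example usage (a sketch, not part of the original module): the config values and
# prompt below are placeholders. __call__ splits the prompt on the literal
# "## Input Data" delimiter, so the prompt must contain that marker exactly once.
if __name__ == "__main__":
    example_config = {
        "model_name": "my-model",              # placeholder model identifier
        "api_key": "YOUR_API_KEY",             # placeholder credential
        "api_url": "https://api.example.com",  # placeholder base URL; '/v1/chat/completions' is appended
        "temperature": 0.0,
        "max_new_tokens": 512,
        "top_p": 1.0,
    }
    agent = LMAgent(example_config)
    prompt = "You are a reranking assistant.\n## Input Data\nquery: example query\ncandidates: ..."
    print(agent(system_prompt=None, input_prompt=prompt, max_new_tokens=512))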