mirror of
https://github.com/infiniflow/ragflow.git
synced 2026-05-06 10:17:49 +08:00
Feat: Agent api (#14157)
### What problem does this PR solve?
1. **List agents**
**Prev API**:
- `/v1/canvas/list GET`
- `/api/v1/agents GET`
**Current API**: `/api/v2/agents GET`
2. **Get canvas template**
**Prev API**: `/v1/canvas/templates GET`
**Current API**: `/api/v2/agents/templates GET`
3. **Delete an agent**
**Prev API**:
- `/v1/canvas/rm POST`
- `/api/v1/agents/<agent_id> DELETE`
**Current API**: `/api/v2/agents/<agent_id> DELETE`
4. **Update an agent**
**Prev API**:
- `/api/v1/agents/<agent_id> PUT`
- `/v1/canvas/setting POST `
**Current API**: `/api/v2/agents/<agent_id> PATCH`
5. **Create an agent**
**Prev API**:
- `/v1/canvas/set POST`
- `/api/v1/agents POST`
**Current API**: `/api/v2/agents POST`
6. **Get an agent**
**Prev API**:
- `/v1/canvas/get/<canvas_id> GET `
**Current API**: `/api/v2/agents/<agent_id> GET`
7. **Reset an agent**
**Prev API**:
- `/v1/canvas/reset POST`
**Current API**: `/api/v2/agents/<agent_id>/reset POST`
8. **Upload a file to an agent**
**Prev API**:
- `/v1/canvas/upload/<canvas_id> POST`
**Current API**: `/api/v2/agents/<agent_id>/upload POST`
9. **Input form**
**Prev API**:
- `/v1/canvas/input_form GET`
**Current API**:
`/api/v2/agents/<agent_id>/components/<component_id>/input-form GET`
10. **Debug an agent**
**Prev API**:
- `/v1/canvas/debug POST`
**Current API**:
`/api/v2/agents/<agent_id>/components/<component_id>/debug POST`
11. **Trace an agent**
**Prev API**:
- `/v1/canvas/trace GET`
**Current API**: `/api/v2/agents/<agent_id>/logs/<message_id> GET`
12. **Get an agent version list**
**Prev API**:
- `/v1/canvas/getlistversion/<canvas_id>`
**Current API**: `/api/v2/agents/<agent_id>/versions GET`
13. **Get a version of agent**
**Prev API**:
- `/v1/canvas/getversion/<version_id>`
**Current API**: `/api/v2/agents/<agent_id>/versions/<version_id> GET`
14. **Test db connection**
**Prev API**:
- `/v1/canvas/test_db_connect POST`
**Current API**: `/api/v2/agents/test_db_connection`
15. **Rerun the agent**
**Prev API**:
- `/v1/canvas/rerun POST`
**Current API**: `/api/v2/agents/rerun POST`
16. **Get prompts**
**Prev API**:
- `/v1/canvas/prompts GET`
**Current API**: `/api/v2/agents/prompts GET`
### Type of change
- [x] New Feature (non-breaking change which adds functionality)
---------
Co-authored-by: chanx <1243304602@qq.com>
This commit is contained in:
@ -13,330 +13,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
#
|
||||
import copy
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from functools import partial
|
||||
from quart import request, Response, make_response
|
||||
from agent.component import LLM
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import CanvasTemplateService, UserCanvasService, API4ConversationService
|
||||
from api.db.services.document_service import DocumentService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.knowledgebase_service import KnowledgebaseService
|
||||
from api.db.services.pipeline_operation_log_service import PipelineOperationLogService
|
||||
from api.db.services.task_service import queue_dataflow, CANVAS_DEBUG_DOC_ID, TaskService
|
||||
from api.db.services.user_service import TenantService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid, thread_pool_exec
|
||||
from api.utils.api_utils import (
|
||||
get_json_result,
|
||||
server_error_response,
|
||||
validate_request,
|
||||
get_data_error_result,
|
||||
get_request_json,
|
||||
)
|
||||
from agent.canvas import Canvas
|
||||
from agent.dsl_migration import normalize_chunker_dsl
|
||||
from peewee import MySQLDatabase, PostgresqlDatabase
|
||||
from api.db.db_models import APIToken, Task
|
||||
|
||||
from rag.flow.pipeline import Pipeline
|
||||
from rag.nlp import search
|
||||
from api.utils.api_utils import get_json_result
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
from common import settings
|
||||
from api.apps import login_required, current_user
|
||||
from api.apps.services.canvas_replica_service import CanvasReplicaService
|
||||
from api.db.services.canvas_service import completion as agent_completion
|
||||
|
||||
|
||||
@manager.route('/templates', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def templates():
|
||||
return get_json_result(data=[c.to_dict() for c in CanvasTemplateService.get_all()])
|
||||
|
||||
|
||||
@manager.route('/rm', methods=['POST']) # noqa: F821
|
||||
@validate_request("canvas_ids")
|
||||
@login_required
|
||||
async def rm():
|
||||
req = await get_request_json()
|
||||
for i in req["canvas_ids"]:
|
||||
if not UserCanvasService.accessible(i, current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.delete_by_id(i)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/set', methods=['POST']) # noqa: F821
|
||||
@validate_request("dsl", "title")
|
||||
@login_required
|
||||
async def save():
|
||||
req = await get_request_json()
|
||||
req['release'] = bool(req.get("release", ""))
|
||||
try:
|
||||
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
|
||||
except ValueError as e:
|
||||
return get_data_error_result(message=str(e))
|
||||
cate = req.get("canvas_category", CanvasCategory.Agent)
|
||||
if "id" not in req:
|
||||
req["user_id"] = current_user.id
|
||||
if UserCanvasService.query(user_id=current_user.id, title=req["title"].strip(), canvas_category=cate):
|
||||
return get_data_error_result(message=f"{req['title'].strip()} already exists.")
|
||||
req["id"] = get_uuid()
|
||||
if not UserCanvasService.save(**req):
|
||||
return get_data_error_result(message="Fail to save canvas.")
|
||||
else:
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
UserCanvasService.update_by_id(req["id"], req)
|
||||
# save version
|
||||
UserCanvasVersionService.save_or_replace_latest(
|
||||
user_canvas_id=req["id"],
|
||||
dsl=req["dsl"],
|
||||
title=UserCanvasVersionService.build_version_title(getattr(current_user, "nickname", current_user.id), req.get("title")),
|
||||
release=req.get("release"),
|
||||
)
|
||||
replica_ok = CanvasReplicaService.replace_for_set(
|
||||
canvas_id=req["id"],
|
||||
tenant_id=str(current_user.id),
|
||||
runtime_user_id=str(current_user.id),
|
||||
dsl=req["dsl"],
|
||||
canvas_category=req.get("canvas_category", cate),
|
||||
title=req.get("title", ""),
|
||||
)
|
||||
if not replica_ok:
|
||||
return get_data_error_result(message="canvas saved, but replica sync failed.")
|
||||
return get_json_result(data=req)
|
||||
|
||||
|
||||
@manager.route('/get/<canvas_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get(canvas_id):
|
||||
if not UserCanvasService.accessible(canvas_id, current_user.id):
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
e, c = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
try:
|
||||
# DELETE
|
||||
CanvasReplicaService.bootstrap(
|
||||
canvas_id=canvas_id,
|
||||
tenant_id=str(current_user.id),
|
||||
runtime_user_id=str(current_user.id),
|
||||
dsl=c.get("dsl"),
|
||||
canvas_category=c.get("canvas_category", CanvasCategory.Agent),
|
||||
title=c.get("title", ""),
|
||||
)
|
||||
except ValueError as e:
|
||||
return get_data_error_result(message=str(e))
|
||||
|
||||
# Get the last publication time (latest released version's update_time)
|
||||
last_publish_time = None
|
||||
versions = UserCanvasVersionService.list_by_canvas_id(canvas_id)
|
||||
if versions:
|
||||
released_versions = [v for v in versions if v.release]
|
||||
if released_versions:
|
||||
# Sort by update_time descending and get the latest
|
||||
released_versions.sort(key=lambda x: x.update_time, reverse=True)
|
||||
last_publish_time = released_versions[0].update_time
|
||||
|
||||
# Add last_publish_time to response data
|
||||
if isinstance(c, dict):
|
||||
c["dsl"] = normalize_chunker_dsl(c.get("dsl", {}))
|
||||
c["last_publish_time"] = last_publish_time
|
||||
else:
|
||||
# If c is a model object, convert to dict first
|
||||
c = c.to_dict()
|
||||
c["dsl"] = normalize_chunker_dsl(c.get("dsl", {}))
|
||||
c["last_publish_time"] = last_publish_time
|
||||
|
||||
# For pipeline type, get associated datasets
|
||||
if c.get("canvas_category") == CanvasCategory.DataFlow:
|
||||
datasets = list(KnowledgebaseService.query(pipeline_id=canvas_id))
|
||||
c["datasets"] = [{"id": d.id, "name": d.name, "avatar": d.avatar} for d in datasets]
|
||||
|
||||
return get_json_result(data=c)
|
||||
|
||||
|
||||
@manager.route('/getsse/<canvas_id>', methods=['GET']) # type: ignore # noqa: F821
|
||||
def getsse(canvas_id):
|
||||
token = request.headers.get('Authorization').split()
|
||||
if len(token) != 2:
|
||||
return get_data_error_result(message='Authorization is not valid!')
|
||||
token = token[1]
|
||||
objs = APIToken.query(beta=token)
|
||||
if not objs:
|
||||
return get_data_error_result(message='Authentication error: API key is invalid!"')
|
||||
tenant_id = objs[0].tenant_id
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=canvas_id):
|
||||
return get_json_result(
|
||||
data=False,
|
||||
message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR
|
||||
)
|
||||
e, c = UserCanvasService.get_by_id(canvas_id)
|
||||
if not e or c.user_id != tenant_id:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
return get_json_result(data=c.to_dict())
|
||||
|
||||
|
||||
@manager.route('/completion', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
async def run():
|
||||
req = await get_request_json()
|
||||
query = req.get("query", "")
|
||||
files = req.get("files", [])
|
||||
inputs = req.get("inputs", {})
|
||||
tenant_id = str(current_user.id)
|
||||
runtime_user_id = req.get("user_id") or tenant_id
|
||||
user_id = str(runtime_user_id)
|
||||
if not await thread_pool_exec(UserCanvasService.accessible, req["id"], tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
replica_payload = CanvasReplicaService.load_for_run(
|
||||
canvas_id=req["id"],
|
||||
tenant_id=tenant_id,
|
||||
runtime_user_id=user_id,
|
||||
)
|
||||
|
||||
if not replica_payload:
|
||||
return get_data_error_result(message="canvas replica not found, please call /get/<canvas_id> first.")
|
||||
|
||||
replica_dsl = replica_payload.get("dsl", {})
|
||||
canvas_title = replica_payload.get("title", "")
|
||||
canvas_category = replica_payload.get("canvas_category", CanvasCategory.Agent)
|
||||
dsl_str = json.dumps(replica_dsl, ensure_ascii=False)
|
||||
|
||||
_, cvs = await thread_pool_exec(UserCanvasService.get_by_id, req["id"])
|
||||
if cvs.canvas_category == CanvasCategory.DataFlow:
|
||||
task_id = get_uuid()
|
||||
Pipeline(dsl_str, tenant_id=tenant_id, doc_id=CANVAS_DEBUG_DOC_ID, task_id=task_id, flow_id=req["id"])
|
||||
ok, error_message = await thread_pool_exec(queue_dataflow, user_id, req["id"], task_id, CANVAS_DEBUG_DOC_ID, files[0], 0)
|
||||
if not ok:
|
||||
return get_data_error_result(message=error_message)
|
||||
return get_json_result(data={"message_id": task_id})
|
||||
|
||||
try:
|
||||
canvas = Canvas(dsl_str, tenant_id, canvas_id=req["id"])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
async def sse():
|
||||
nonlocal canvas, user_id
|
||||
try:
|
||||
async for ans in canvas.run(query=query, files=files, user_id=user_id, inputs=inputs):
|
||||
yield "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
|
||||
commit_ok = CanvasReplicaService.commit_after_run(
|
||||
canvas_id=req["id"],
|
||||
tenant_id=tenant_id,
|
||||
runtime_user_id=user_id,
|
||||
dsl=json.loads(str(canvas)),
|
||||
canvas_category=canvas_category,
|
||||
title=canvas_title,
|
||||
)
|
||||
if not commit_ok:
|
||||
logging.error(
|
||||
"Canvas runtime replica commit failed: canvas_id=%s tenant_id=%s runtime_user_id=%s",
|
||||
req["id"],
|
||||
tenant_id,
|
||||
user_id,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
canvas.cancel_task()
|
||||
yield "data:" + json.dumps({"code": 500, "message": str(e), "data": False}, ensure_ascii=False) + "\n\n"
|
||||
|
||||
resp = Response(sse(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
#resp.call_on_close(lambda: canvas.cancel_task())
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route("/<canvas_id>/completion", methods=["POST"]) # noqa: F821
|
||||
@login_required
|
||||
async def exp_agent_completion(canvas_id):
|
||||
tenant_id = current_user.id
|
||||
req = await get_request_json()
|
||||
return_trace = bool(req.get("return_trace", False))
|
||||
async def generate():
|
||||
trace_items = []
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=canvas_id, **req):
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
ans = json.loads(answer[5:]) # remove "data:"
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
event = ans.get("event")
|
||||
if event == "node_finished":
|
||||
if return_trace:
|
||||
data = ans.get("data", {})
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
ans.setdefault("data", {})["trace"] = trace_items
|
||||
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
yield answer
|
||||
|
||||
if event not in ["message", "message_end"]:
|
||||
continue
|
||||
|
||||
yield answer
|
||||
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
resp = Response(generate(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
|
||||
@manager.route('/rerun', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "dsl", "component_id")
|
||||
@login_required
|
||||
async def rerun():
|
||||
req = await get_request_json()
|
||||
doc = PipelineOperationLogService.get_documents_info(req["id"])
|
||||
if not doc:
|
||||
return get_data_error_result(message="Document not found.")
|
||||
doc = doc[0]
|
||||
if 0 < doc["progress"] < 1:
|
||||
return get_data_error_result(message=f"`{doc['name']}` is processing...")
|
||||
|
||||
if settings.docStoreConn.index_exist(search.index_name(current_user.id), doc["kb_id"]):
|
||||
settings.docStoreConn.delete({"doc_id": doc["id"]}, search.index_name(current_user.id), doc["kb_id"])
|
||||
doc["progress_msg"] = ""
|
||||
doc["chunk_num"] = 0
|
||||
doc["token_num"] = 0
|
||||
DocumentService.clear_chunk_num_when_rerun(doc["id"])
|
||||
DocumentService.update_by_id(id, doc)
|
||||
TaskService.filter_delete([Task.doc_id == id])
|
||||
|
||||
dsl = req["dsl"]
|
||||
dsl["path"] = [req["component_id"]]
|
||||
PipelineOperationLogService.update_by_id(req["id"], {"dsl": dsl})
|
||||
queue_dataflow(tenant_id=current_user.id, flow_id=req["id"], task_id=get_uuid(), doc_id=doc["id"], priority=0, rerun=True)
|
||||
return get_json_result(data=True)
|
||||
from api.apps import login_required
|
||||
|
||||
|
||||
@manager.route('/cancel/<task_id>', methods=['PUT']) # noqa: F821
|
||||
@ -347,409 +27,3 @@ def cancel(task_id):
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route('/reset', methods=['POST']) # noqa: F821
|
||||
@validate_request("id")
|
||||
@login_required
|
||||
async def reset():
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
|
||||
canvas.reset()
|
||||
req["dsl"] = json.loads(str(canvas))
|
||||
UserCanvasService.update_by_id(req["id"], {"dsl": req["dsl"]})
|
||||
return get_json_result(data=req["dsl"])
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route("/upload/<canvas_id>", methods=["POST"]) # noqa: F821
|
||||
async def upload(canvas_id):
|
||||
e, cvs = UserCanvasService.get_by_canvas_id(canvas_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
|
||||
user_id = cvs["user_id"]
|
||||
files = await request.files
|
||||
file_objs = files.getlist("file") if files and files.get("file") else []
|
||||
try:
|
||||
if len(file_objs) == 1:
|
||||
return get_json_result(data=FileService.upload_info(user_id, file_objs[0], request.args.get("url")))
|
||||
results = [FileService.upload_info(user_id, f) for f in file_objs]
|
||||
return get_json_result(data=results)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/input_form', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def input_form():
|
||||
cvs_id = request.args.get("id")
|
||||
cpn_id = request.args.get("component_id")
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(cvs_id)
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
if not UserCanvasService.query(user_id=current_user.id, id=cvs_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
|
||||
return get_json_result(data=canvas.get_component_input_form(cpn_id))
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/debug', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "component_id", "params")
|
||||
@login_required
|
||||
async def debug():
|
||||
req = await get_request_json()
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
try:
|
||||
e, user_canvas = UserCanvasService.get_by_id(req["id"])
|
||||
canvas = Canvas(json.dumps(user_canvas.dsl), current_user.id, canvas_id=user_canvas.id)
|
||||
canvas.reset()
|
||||
canvas.message_id = get_uuid()
|
||||
component = canvas.get_component(req["component_id"])["obj"]
|
||||
component.reset()
|
||||
|
||||
if isinstance(component, LLM):
|
||||
component.set_debug_inputs(req["params"])
|
||||
component.invoke(**{k: o["value"] for k,o in req["params"].items()})
|
||||
outputs = component.output()
|
||||
for k in outputs.keys():
|
||||
if isinstance(outputs[k], partial):
|
||||
txt = ""
|
||||
iter_obj = outputs[k]()
|
||||
if inspect.isasyncgen(iter_obj):
|
||||
async for c in iter_obj:
|
||||
txt += c
|
||||
else:
|
||||
for c in iter_obj:
|
||||
txt += c
|
||||
outputs[k] = txt
|
||||
return get_json_result(data=outputs)
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/test_db_connect', methods=['POST']) # noqa: F821
|
||||
@validate_request("db_type", "database", "username", "host", "port", "password")
|
||||
@login_required
|
||||
async def test_db_connect():
|
||||
req = await get_request_json()
|
||||
try:
|
||||
if req["db_type"] in ["mysql", "mariadb"]:
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
password=req["password"])
|
||||
elif req["db_type"] == "oceanbase":
|
||||
db = MySQLDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
password=req["password"], charset="utf8mb4")
|
||||
elif req["db_type"] == 'postgres':
|
||||
db = PostgresqlDatabase(req["database"], user=req["username"], host=req["host"], port=req["port"],
|
||||
password=req["password"])
|
||||
elif req["db_type"] == 'mssql':
|
||||
import pyodbc
|
||||
connection_string = (
|
||||
f"DRIVER={{ODBC Driver 17 for SQL Server}};"
|
||||
f"SERVER={req['host']},{req['port']};"
|
||||
f"DATABASE={req['database']};"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
db = pyodbc.connect(connection_string)
|
||||
cursor = db.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
cursor.close()
|
||||
elif req["db_type"] == 'IBM DB2':
|
||||
import ibm_db
|
||||
conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={req['host']};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
f"PWD={req['password']};"
|
||||
)
|
||||
redacted_conn_str = (
|
||||
f"DATABASE={req['database']};"
|
||||
f"HOSTNAME={req['host']};"
|
||||
f"PORT={req['port']};"
|
||||
f"PROTOCOL=TCPIP;"
|
||||
f"UID={req['username']};"
|
||||
f"PWD=****;"
|
||||
)
|
||||
logging.info(redacted_conn_str)
|
||||
conn = ibm_db.connect(conn_str, "", "")
|
||||
stmt = ibm_db.exec_immediate(conn, "SELECT 1 FROM sysibm.sysdummy1")
|
||||
ibm_db.fetch_assoc(stmt)
|
||||
ibm_db.close(conn)
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
elif req["db_type"] == 'trino':
|
||||
def _parse_catalog_schema(db_name: str):
|
||||
if not db_name:
|
||||
return None, None
|
||||
if "." in db_name:
|
||||
catalog_name, schema_name = db_name.split(".", 1)
|
||||
elif "/" in db_name:
|
||||
catalog_name, schema_name = db_name.split("/", 1)
|
||||
else:
|
||||
catalog_name, schema_name = db_name, "default"
|
||||
return catalog_name, schema_name
|
||||
try:
|
||||
import trino
|
||||
import os
|
||||
except Exception as e:
|
||||
return server_error_response(f"Missing dependency 'trino'. Please install: pip install trino, detail: {e}")
|
||||
|
||||
catalog, schema = _parse_catalog_schema(req["database"])
|
||||
if not catalog:
|
||||
return server_error_response("For Trino, 'database' must be 'catalog.schema' or at least 'catalog'.")
|
||||
|
||||
http_scheme = "https" if os.environ.get("TRINO_USE_TLS", "0") == "1" else "http"
|
||||
|
||||
auth = None
|
||||
if http_scheme == "https" and req.get("password"):
|
||||
auth = trino.BasicAuthentication(req.get("username") or "ragflow", req["password"])
|
||||
|
||||
conn = trino.dbapi.connect(
|
||||
host=req["host"],
|
||||
port=int(req["port"] or 8080),
|
||||
user=req["username"] or "ragflow",
|
||||
catalog=catalog,
|
||||
schema=schema or "default",
|
||||
http_scheme=http_scheme,
|
||||
auth=auth
|
||||
)
|
||||
cur = conn.cursor()
|
||||
cur.execute("SELECT 1")
|
||||
cur.fetchall()
|
||||
cur.close()
|
||||
conn.close()
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
else:
|
||||
return server_error_response("Unsupported database type.")
|
||||
if req["db_type"] != 'mssql':
|
||||
db.connect()
|
||||
db.close()
|
||||
|
||||
return get_json_result(data="Database Connection Successful!")
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
#api get list version dsl of canvas
|
||||
@manager.route('/getlistversion/<canvas_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def getlistversion(canvas_id):
|
||||
try:
|
||||
versions =sorted([c.to_dict() for c in UserCanvasVersionService.list_by_canvas_id(canvas_id)], key=lambda x: x["update_time"]*-1)
|
||||
return get_json_result(data=versions)
|
||||
except Exception as e:
|
||||
return get_data_error_result(message=f"Error getting history files: {e}")
|
||||
|
||||
|
||||
#api get version dsl of canvas
|
||||
@manager.route('/getversion/<version_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def getversion( version_id):
|
||||
try:
|
||||
e, version = UserCanvasVersionService.get_by_id(version_id)
|
||||
if version:
|
||||
return get_json_result(data=version.to_dict())
|
||||
except Exception as e:
|
||||
return get_json_result(data=f"Error getting history file: {e}")
|
||||
|
||||
|
||||
@manager.route('/list', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def list_canvas():
|
||||
keywords = request.args.get("keywords", "")
|
||||
page_number = int(request.args.get("page", 0))
|
||||
items_per_page = int(request.args.get("page_size", 0))
|
||||
orderby = request.args.get("orderby", "create_time")
|
||||
canvas_category = request.args.get("canvas_category")
|
||||
if request.args.get("desc", "true").lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
owner_ids = [id for id in request.args.get("owner_ids", "").strip().split(",") if id]
|
||||
if not owner_ids:
|
||||
tenants = TenantService.get_joined_tenants_by_user_id(current_user.id)
|
||||
tenants = [m["tenant_id"] for m in tenants]
|
||||
tenants.append(current_user.id)
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
tenants, current_user.id, page_number,
|
||||
items_per_page, orderby, desc, keywords, canvas_category)
|
||||
else:
|
||||
tenants = owner_ids
|
||||
canvas, total = UserCanvasService.get_by_tenant_ids(
|
||||
tenants, current_user.id, 0,
|
||||
0, orderby, desc, keywords, canvas_category)
|
||||
return get_json_result(data={"canvas": canvas, "total": total})
|
||||
|
||||
|
||||
@manager.route('/setting', methods=['POST']) # noqa: F821
|
||||
@validate_request("id", "title", "permission")
|
||||
@login_required
|
||||
async def setting():
|
||||
req = await get_request_json()
|
||||
req["user_id"] = current_user.id
|
||||
|
||||
if not UserCanvasService.accessible(req["id"], current_user.id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
e,flow = UserCanvasService.get_by_id(req["id"])
|
||||
if not e:
|
||||
return get_data_error_result(message="canvas not found.")
|
||||
flow = flow.to_dict()
|
||||
flow["title"] = req["title"]
|
||||
|
||||
for key in ["description", "permission", "avatar"]:
|
||||
if value := req.get(key):
|
||||
flow[key] = value
|
||||
|
||||
num= UserCanvasService.update_by_id(req["id"], flow)
|
||||
return get_json_result(data=num)
|
||||
|
||||
|
||||
@manager.route('/trace', methods=['GET']) # noqa: F821
|
||||
def trace():
|
||||
cvs_id = request.args.get("canvas_id")
|
||||
msg_id = request.args.get("message_id")
|
||||
try:
|
||||
binary = REDIS_CONN.get(f"{cvs_id}-{msg_id}-logs")
|
||||
if not binary:
|
||||
return get_json_result(data={})
|
||||
|
||||
return get_json_result(data=json.loads(binary.encode("utf-8")))
|
||||
except Exception as e:
|
||||
logging.exception(e)
|
||||
|
||||
|
||||
@manager.route('/<canvas_id>/sessions', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def sessions(canvas_id):
|
||||
tenant_id = current_user.id
|
||||
if not UserCanvasService.accessible(canvas_id, tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
user_id = request.args.get("user_id")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
keywords = request.args.get("keywords")
|
||||
from_date = request.args.get("from_date")
|
||||
to_date = request.args.get("to_date")
|
||||
orderby = request.args.get("orderby", "update_time")
|
||||
exp_user_id = request.args.get("exp_user_id")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
|
||||
if exp_user_id:
|
||||
sess = API4ConversationService.get_names(canvas_id, exp_user_id)
|
||||
return get_json_result(data={"total": len(sess), "sessions": sess})
|
||||
|
||||
# dsl defaults to True in all cases except for False and false
|
||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||
total, sess = API4ConversationService.get_list(canvas_id, tenant_id, page_number, items_per_page, orderby, desc,
|
||||
None, user_id, include_dsl, keywords, from_date, to_date, exp_user_id=exp_user_id)
|
||||
try:
|
||||
return get_json_result(data={"total": total, "sessions": sess})
|
||||
except Exception as e:
|
||||
return server_error_response(e)
|
||||
|
||||
|
||||
@manager.route('/<canvas_id>/sessions', methods=['PUT']) # noqa: F821
|
||||
@login_required
|
||||
async def set_session(canvas_id):
|
||||
req = await get_request_json()
|
||||
tenant_id = current_user.id
|
||||
e, cvs = UserCanvasService.get_by_id(canvas_id)
|
||||
assert e, "Agent not found."
|
||||
if not isinstance(cvs.dsl, str):
|
||||
cvs.dsl = json.dumps(cvs.dsl, ensure_ascii=False)
|
||||
session_id=get_uuid()
|
||||
canvas = Canvas(cvs.dsl, tenant_id, canvas_id, canvas_id=cvs.id)
|
||||
canvas.reset()
|
||||
# Get the version title for this canvas (using latest, not necessarily released)
|
||||
version_title = UserCanvasVersionService.get_latest_version_title(cvs.id, release_mode=False)
|
||||
conv = {
|
||||
"id": session_id,
|
||||
"name": req.get("name", ""),
|
||||
"dialog_id": cvs.id,
|
||||
"user_id": tenant_id,
|
||||
"exp_user_id": tenant_id,
|
||||
"message": [],
|
||||
"source": "agent",
|
||||
"dsl": cvs.dsl,
|
||||
"reference": [],
|
||||
"version_title": version_title
|
||||
}
|
||||
API4ConversationService.save(**conv)
|
||||
return get_json_result(data=conv)
|
||||
|
||||
|
||||
@manager.route('/<canvas_id>/sessions/<session_id>', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def get_session(canvas_id, session_id):
|
||||
tenant_id = current_user.id
|
||||
if not UserCanvasService.accessible(canvas_id, tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
_, conv = API4ConversationService.get_by_id(session_id)
|
||||
return get_json_result(data=conv.to_dict())
|
||||
|
||||
|
||||
@manager.route('/<canvas_id>/sessions/<session_id>', methods=['DELETE']) # noqa: F821
|
||||
@login_required
|
||||
def del_session(canvas_id, session_id):
|
||||
tenant_id = current_user.id
|
||||
if not UserCanvasService.accessible(canvas_id, tenant_id):
|
||||
return get_json_result(
|
||||
data=False, message='Only owner of canvas authorized for this operation.',
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
return get_json_result(data=API4ConversationService.delete_by_id(session_id))
|
||||
|
||||
|
||||
@manager.route('/prompts', methods=['GET']) # noqa: F821
|
||||
@login_required
|
||||
def prompts():
|
||||
from rag.prompts.generator import ANALYZE_TASK_SYSTEM, ANALYZE_TASK_USER, NEXT_STEP, REFLECT, CITATION_PROMPT_TEMPLATE
|
||||
|
||||
return get_json_result(data={
|
||||
"task_analysis": ANALYZE_TASK_SYSTEM +"\n\n"+ ANALYZE_TASK_USER,
|
||||
"plan_generation": NEXT_STEP,
|
||||
"reflection": REFLECT,
|
||||
#"context_summary": SUMMARY4MEMORY,
|
||||
#"context_ranking": RANK_MEMORY,
|
||||
"citation_guidelines": CITATION_PROMPT_TEMPLATE
|
||||
})
|
||||
|
||||
|
||||
@manager.route('/download', methods=['GET']) # noqa: F821
|
||||
async def download():
|
||||
id = request.args.get("id")
|
||||
created_by = request.args.get("created_by")
|
||||
blob = FileService.get_blob(created_by, id)
|
||||
return await make_response(blob)
|
||||
|
||||
1047
api/apps/restful_apis/agent_api.py
Normal file
1047
api/apps/restful_apis/agent_api.py
Normal file
File diff suppressed because it is too large
Load Diff
@ -22,137 +22,18 @@ import ipaddress
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import Any, cast
|
||||
|
||||
import jwt
|
||||
|
||||
from agent.canvas import Canvas
|
||||
from api.apps.services.canvas_replica_service import CanvasReplicaService
|
||||
from api.db import CanvasCategory
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.file_service import FileService
|
||||
from api.db.services.user_service import UserService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
from common.constants import RetCode
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import get_data_error_result, get_error_data_result, get_json_result, get_request_json, token_required
|
||||
from api.utils.api_utils import get_result
|
||||
from api.utils.api_utils import get_data_error_result, get_json_result
|
||||
from quart import request, Response
|
||||
from rag.utils.redis_conn import REDIS_CONN
|
||||
|
||||
|
||||
def _get_user_nickname(user_id: str) -> str:
|
||||
exists, user = UserService.get_by_id(user_id)
|
||||
if not exists:
|
||||
return user_id
|
||||
return str(getattr(user, "nickname", "") or user_id)
|
||||
|
||||
|
||||
@manager.route('/agents', methods=['GET']) # noqa: F821
|
||||
@token_required
|
||||
def list_agents(tenant_id):
|
||||
id = request.args.get("id")
|
||||
title = request.args.get("title")
|
||||
if id or title:
|
||||
canvas = UserCanvasService.query(id=id, title=title, user_id=tenant_id)
|
||||
if not canvas:
|
||||
return get_error_data_result("The agent doesn't exist.")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
order_by = request.args.get("orderby", "update_time")
|
||||
if str(request.args.get("desc","false")).lower() == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
canvas = UserCanvasService.get_list(tenant_id, page_number, items_per_page, order_by, desc, id, title)
|
||||
return get_result(data=canvas)
|
||||
|
||||
|
||||
@manager.route("/agents", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def create_agent(tenant_id: str):
|
||||
req: dict[str, Any] = cast(dict[str, Any], await get_request_json())
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
try:
|
||||
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
|
||||
except ValueError as e:
|
||||
return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR)
|
||||
else:
|
||||
return get_json_result(data=False, message="No DSL data in request.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if req.get("title") is not None:
|
||||
req["title"] = req["title"].strip()
|
||||
else:
|
||||
return get_json_result(data=False, message="No title in request.", code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if UserCanvasService.query(user_id=tenant_id, title=req["title"]):
|
||||
return get_data_error_result(message=f"Agent with title {req['title']} already exists.")
|
||||
|
||||
agent_id = get_uuid()
|
||||
req["id"] = agent_id
|
||||
|
||||
if not UserCanvasService.save(**req):
|
||||
return get_data_error_result(message="Fail to create agent.")
|
||||
|
||||
owner_nickname = _get_user_nickname(tenant_id)
|
||||
UserCanvasVersionService.save_or_replace_latest(
|
||||
user_canvas_id=agent_id,
|
||||
title=UserCanvasVersionService.build_version_title(owner_nickname, req.get("title")),
|
||||
dsl=req["dsl"]
|
||||
)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>", methods=["PUT"]) # noqa: F821
|
||||
@token_required
|
||||
async def update_agent(tenant_id: str, agent_id: str):
|
||||
req: dict[str, Any] = {k: v for k, v in cast(dict[str, Any], (await get_request_json())).items() if v is not None}
|
||||
req["user_id"] = tenant_id
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
try:
|
||||
req["dsl"] = CanvasReplicaService.normalize_dsl(req["dsl"])
|
||||
except ValueError as e:
|
||||
return get_json_result(data=False, message=str(e), code=RetCode.ARGUMENT_ERROR)
|
||||
|
||||
if req.get("title") is not None:
|
||||
req["title"] = req["title"].strip()
|
||||
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_json_result(
|
||||
data=False, message="Only owner of canvas authorized for this operation.",
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
_, current_agent = UserCanvasService.get_by_id(agent_id)
|
||||
agent_title_for_version = req.get("title") or (current_agent.title if current_agent else "")
|
||||
owner_nickname = _get_user_nickname(tenant_id)
|
||||
|
||||
UserCanvasService.update_by_id(agent_id, req)
|
||||
|
||||
if req.get("dsl") is not None:
|
||||
UserCanvasVersionService.save_or_replace_latest(
|
||||
user_canvas_id=agent_id,
|
||||
title=UserCanvasVersionService.build_version_title(owner_nickname, agent_title_for_version),
|
||||
dsl=req["dsl"]
|
||||
)
|
||||
|
||||
return get_json_result(data=True)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
def delete_agent(tenant_id: str, agent_id: str):
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_json_result(
|
||||
data=False, message="Only owner of canvas authorized for this operation.",
|
||||
code=RetCode.OPERATING_ERROR)
|
||||
|
||||
UserCanvasService.delete_by_id(agent_id)
|
||||
return get_json_result(data=True)
|
||||
|
||||
@manager.route("/webhook/<agent_id>", methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"]) # noqa: F821
|
||||
@manager.route("/webhook_test/<agent_id>",methods=["POST", "GET", "PUT", "PATCH", "DELETE", "HEAD"],) # noqa: F821
|
||||
async def webhook(agent_id: str):
|
||||
|
||||
@ -14,7 +14,6 @@
|
||||
# limitations under the License.
|
||||
#
|
||||
import json
|
||||
import copy
|
||||
import re
|
||||
import time
|
||||
|
||||
@ -29,7 +28,7 @@ from common.token_utils import num_tokens_from_string
|
||||
from agent.canvas import Canvas
|
||||
from api.db.db_models import APIToken
|
||||
from api.db.services.api_service import API4ConversationService
|
||||
from api.db.services.canvas_service import UserCanvasService, completion_openai
|
||||
from api.db.services.canvas_service import UserCanvasService
|
||||
from api.db.services.canvas_service import completion as agent_completion
|
||||
from api.db.services.conversation_service import ConversationService
|
||||
from api.db.services.user_canvas_version import UserCanvasVersionService
|
||||
@ -45,7 +44,7 @@ from api.db.services.user_service import UserTenantService
|
||||
from api.db.joint_services.tenant_model_service import get_tenant_default_model_by_type, get_model_config_by_id, \
|
||||
get_model_config_by_type_and_name
|
||||
from common.misc_utils import get_uuid
|
||||
from api.utils.api_utils import check_duplicate_ids, get_data_openai, get_error_data_result, get_json_result, \
|
||||
from api.utils.api_utils import check_duplicate_ids, get_error_data_result, get_json_result, \
|
||||
get_result, get_request_json, server_error_response, token_required, validate_request
|
||||
from rag.app.tag import label_question
|
||||
from rag.prompts.template import load_prompt
|
||||
@ -54,7 +53,6 @@ from common.constants import RetCode, LLMType, StatusEnum
|
||||
from common import settings
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def create_agent_session(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
@ -435,215 +433,6 @@ async def chat_completion_openai_like(tenant_id, chat_id):
|
||||
return jsonify(response)
|
||||
|
||||
|
||||
@manager.route("/agents_openai/<agent_id>/chat/completions", methods=["POST"]) # noqa: F821
|
||||
@validate_request("model", "messages") # noqa: F821
|
||||
@token_required
|
||||
async def agents_completion_openai_compatibility(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
messages = req.get("messages", [])
|
||||
if not messages:
|
||||
return get_error_data_result("You must provide at least one message.")
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result(f"You don't own the agent {agent_id}")
|
||||
|
||||
filtered_messages = [m for m in messages if m["role"] in ["user", "assistant"]]
|
||||
prompt_tokens = sum(num_tokens_from_string(m["content"]) for m in filtered_messages)
|
||||
if not filtered_messages:
|
||||
return jsonify(
|
||||
get_data_openai(
|
||||
id=agent_id,
|
||||
content="No valid messages found (user or assistant).",
|
||||
finish_reason="stop",
|
||||
model=req.get("model", ""),
|
||||
completion_tokens=num_tokens_from_string("No valid messages found (user or assistant)."),
|
||||
prompt_tokens=prompt_tokens,
|
||||
)
|
||||
)
|
||||
|
||||
question = next((m["content"] for m in reversed(messages) if m["role"] == "user"), "")
|
||||
|
||||
stream = req.pop("stream", False)
|
||||
if stream:
|
||||
resp = Response(
|
||||
completion_openai(
|
||||
tenant_id,
|
||||
agent_id,
|
||||
question,
|
||||
session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
|
||||
stream=True,
|
||||
**req,
|
||||
),
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
else:
|
||||
# For non-streaming, just return the response directly
|
||||
async for response in completion_openai(
|
||||
tenant_id,
|
||||
agent_id,
|
||||
question,
|
||||
session_id=req.pop("session_id", req.get("id", "")) or req.get("metadata", {}).get("id", ""),
|
||||
stream=False,
|
||||
**req,
|
||||
):
|
||||
return jsonify(response)
|
||||
|
||||
return None
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/completions", methods=["POST"]) # noqa: F821
|
||||
@token_required
|
||||
async def agent_completions(tenant_id, agent_id):
|
||||
req = await get_request_json()
|
||||
return_trace = bool(req.get("return_trace", False))
|
||||
|
||||
if req.get("stream", True):
|
||||
|
||||
async def generate():
|
||||
trace_items = []
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
if isinstance(answer, str):
|
||||
try:
|
||||
ans = json.loads(answer[5:]) # remove "data:"
|
||||
except Exception:
|
||||
continue
|
||||
|
||||
event = ans.get("event")
|
||||
if event == "node_finished":
|
||||
if return_trace:
|
||||
data = ans.get("data", {})
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
ans.setdefault("data", {})["trace"] = trace_items
|
||||
answer = "data:" + json.dumps(ans, ensure_ascii=False) + "\n\n"
|
||||
yield answer
|
||||
|
||||
if event not in ["message", "message_end"]:
|
||||
continue
|
||||
|
||||
yield answer
|
||||
|
||||
yield "data:[DONE]\n\n"
|
||||
|
||||
resp = Response(generate(), mimetype="text/event-stream")
|
||||
resp.headers.add_header("Cache-control", "no-cache")
|
||||
resp.headers.add_header("Connection", "keep-alive")
|
||||
resp.headers.add_header("X-Accel-Buffering", "no")
|
||||
resp.headers.add_header("Content-Type", "text/event-stream; charset=utf-8")
|
||||
return resp
|
||||
|
||||
full_content = ""
|
||||
reference = {}
|
||||
final_ans = ""
|
||||
trace_items = []
|
||||
structured_output = {}
|
||||
async for answer in agent_completion(tenant_id=tenant_id, agent_id=agent_id, **req):
|
||||
try:
|
||||
ans = json.loads(answer[5:])
|
||||
|
||||
if ans["event"] == "message":
|
||||
full_content += ans["data"]["content"]
|
||||
|
||||
if ans.get("data", {}).get("reference", None):
|
||||
reference.update(ans["data"]["reference"])
|
||||
|
||||
if ans.get("event") == "node_finished":
|
||||
data = ans.get("data", {})
|
||||
node_out = data.get("outputs", {})
|
||||
component_id = data.get("component_id")
|
||||
if component_id is not None and "structured" in node_out:
|
||||
structured_output[component_id] = copy.deepcopy(node_out["structured"])
|
||||
if return_trace:
|
||||
trace_items.append(
|
||||
{
|
||||
"component_id": data.get("component_id"),
|
||||
"trace": [copy.deepcopy(data)],
|
||||
}
|
||||
)
|
||||
|
||||
final_ans = ans
|
||||
except Exception as e:
|
||||
return get_result(data=f"**ERROR**: {str(e)}")
|
||||
final_ans["data"]["content"] = full_content
|
||||
final_ans["data"]["reference"] = reference
|
||||
if structured_output:
|
||||
final_ans["data"]["structured"] = structured_output
|
||||
if return_trace and final_ans:
|
||||
final_ans["data"]["trace"] = trace_items
|
||||
return get_result(data=final_ans)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["GET"]) # noqa: F821
|
||||
@token_required
|
||||
async def list_agent_session(tenant_id, agent_id):
|
||||
if not UserCanvasService.query(user_id=tenant_id, id=agent_id):
|
||||
return get_error_data_result(message=f"You don't own the agent {agent_id}.")
|
||||
id = request.args.get("id")
|
||||
user_id = request.args.get("user_id")
|
||||
page_number = int(request.args.get("page", 1))
|
||||
items_per_page = int(request.args.get("page_size", 30))
|
||||
orderby = request.args.get("orderby", "update_time")
|
||||
if request.args.get("desc") == "False" or request.args.get("desc") == "false":
|
||||
desc = False
|
||||
else:
|
||||
desc = True
|
||||
# dsl defaults to True in all cases except for False and false
|
||||
include_dsl = request.args.get("dsl") != "False" and request.args.get("dsl") != "false"
|
||||
total, convs = API4ConversationService.get_list(agent_id, tenant_id, page_number, items_per_page, orderby, desc, id,
|
||||
user_id, include_dsl)
|
||||
if not convs:
|
||||
return get_result(data=[])
|
||||
for conv in convs:
|
||||
conv["messages"] = conv.pop("message")
|
||||
infos = conv["messages"]
|
||||
for info in infos:
|
||||
if "prompt" in info:
|
||||
info.pop("prompt")
|
||||
conv["agent_id"] = conv.pop("dialog_id")
|
||||
# Fix for session listing endpoint
|
||||
if conv["reference"]:
|
||||
messages = conv["messages"]
|
||||
message_num = 0
|
||||
chunk_num = 0
|
||||
# Ensure reference is a list type to prevent KeyError
|
||||
if not isinstance(conv["reference"], list):
|
||||
conv["reference"] = []
|
||||
while message_num < len(messages):
|
||||
if message_num != 0 and messages[message_num]["role"] != "user":
|
||||
chunk_list = []
|
||||
# Add boundary and type checks to prevent KeyError
|
||||
if chunk_num < len(conv["reference"]) and conv["reference"][chunk_num] is not None and isinstance(
|
||||
conv["reference"][chunk_num], dict) and "chunks" in conv["reference"][chunk_num]:
|
||||
chunks = conv["reference"][chunk_num]["chunks"]
|
||||
for chunk in chunks:
|
||||
# Ensure chunk is a dictionary before calling get method
|
||||
if not isinstance(chunk, dict):
|
||||
continue
|
||||
new_chunk = {
|
||||
"id": chunk.get("chunk_id", chunk.get("id")),
|
||||
"content": chunk.get("content_with_weight", chunk.get("content")),
|
||||
"document_id": chunk.get("doc_id", chunk.get("document_id")),
|
||||
"document_name": chunk.get("docnm_kwd", chunk.get("document_name")),
|
||||
"dataset_id": chunk.get("kb_id", chunk.get("dataset_id")),
|
||||
"image_id": chunk.get("image_id", chunk.get("img_id")),
|
||||
"positions": chunk.get("positions", chunk.get("position_int")),
|
||||
}
|
||||
chunk_list.append(new_chunk)
|
||||
chunk_num += 1
|
||||
messages[message_num]["reference"] = chunk_list
|
||||
message_num += 1
|
||||
del conv["reference"]
|
||||
return get_result(data=convs)
|
||||
|
||||
|
||||
@manager.route("/agents/<agent_id>/sessions", methods=["DELETE"]) # noqa: F821
|
||||
@token_required
|
||||
async def delete_agent_session(tenant_id, agent_id):
|
||||
|
||||
Reference in New Issue
Block a user