190 lines
6.8 KiB
PHP
190 lines
6.8 KiB
PHP
<?php
|
|
|
|
namespace KupShop\LLMBundle\Controller;
|
|
|
|
use KupShop\AdminBundle\AdminRequiredControllerInterface;
|
|
use KupShop\KupShopBundle\Routing\AdminRoute;
|
|
use KupShop\LLMBundle\Dto\PromptResponseStats;
|
|
use KupShop\LLMBundle\Dto\TextPromptInput;
|
|
use KupShop\LLMBundle\Enum\PromptPlaceholder;
|
|
use KupShop\LLMBundle\Util\LlmProviderUtil;
|
|
use KupShop\LLMBundle\Util\TextObjectUtil;
|
|
use KupShop\LLMBundle\Util\TextPromptUtil;
|
|
use KupShop\MetricsBundle\PrometheusWrapper;
|
|
use Psr\Log\LoggerInterface;
|
|
use Symfony\Bundle\FrameworkBundle\Controller\AbstractController;
|
|
use Symfony\Component\HttpFoundation\JsonResponse;
|
|
use Symfony\Component\HttpFoundation\Request;
|
|
use Symfony\Component\HttpFoundation\Response;
|
|
use Symfony\Component\HttpFoundation\StreamedResponse;
|
|
use Symfony\Component\HttpKernel\Exception\BadRequestHttpException;
|
|
use Symfony\Contracts\Service\Attribute\Required;
|
|
|
|
/**
 * Admin endpoints that run configured LLM text prompts (the admin "LLM modal").
 *
 * Implements AdminRequiredControllerInterface — presumably restricting these routes to
 * logged-in administrators (TODO confirm against that interface's consumer).
 */
class AdminLlmController extends AbstractController implements AdminRequiredControllerInterface
{
    #[Required]
    public TextObjectUtil $textObjectUtil;

    #[Required]
    public LlmProviderUtil $llmProviderUtil;

    #[Required]
    public TextPromptUtil $textPromptsUtil;

    /**
     * Optional metrics sink, injected through setPrometheusWrapper().
     *
     * Defaults to null so logStats() can test it safely; a typed property without a default
     * would throw an Error if read before the setter ran.
     */
    public ?PrometheusWrapper $prometheusWrapper = null;

    #[Required]
    public LoggerInterface $logger;

    /**
     * Runs one prompt synchronously and returns the whole answer as {"value": <answer>}.
     *
     * The request body must be a JSON object; see preparePrompt() for the recognised keys.
     *
     * @throws BadRequestHttpException when the body is not a JSON object or lacks required fields
     */
    #[AdminRoute('/llm/prompt-action', methods: ['POST'])]
    public function LLMTextAction(Request $request): JsonResponse
    {
        $data = $this->decodePayload($request->getContent());
        $inputPrompt = $this->preparePrompt($data);

        $provider = $this->llmProviderUtil->getAdminModalProvider();
        $response = $provider->getResponse($inputPrompt);

        $this->logStats($provider->getLastResponseStats(), $inputPrompt, $response, $data);

        return new JsonResponse(['value' => $response]);
    }

    /**
     * Runs one prompt and streams the provider's answer to the client as it arrives.
     *
     * The payload is read from the JSON body for POST requests, and from the `payload` query
     * parameter otherwise. Errors raised while preparing the prompt or opening the stream are
     * turned into a plain-text response by handleException(). Once the stream is exhausted,
     * the concatenated answer is logged via logStats().
     */
    #[AdminRoute('/llm/prompt-stream-action')]
    public function LLMTextStreamAction(Request $request): StreamedResponse|Response
    {
        try {
            if ($request->isMethod(Request::METHOD_POST)) {
                $data = $this->decodePayload($request->getContent());
            } else {
                // NOTE(review): Symfony already URL-decodes query parameters, so this urldecode()
                // is a second pass (it turns '+' into a space). It is kept in case the client
                // double-encodes the payload — confirm against the caller.
                $data = $this->decodePayload(urldecode((string) $request->query->get('payload', '')));
            }

            $inputPrompt = $this->preparePrompt($data);
            $provider = $this->llmProviderUtil->getAdminModalProvider();
            // Presumably a PSR-7-like stream (eof()/read()) — TODO confirm provider contract.
            $response = $provider->getStreamedResponse($inputPrompt);
        } catch (\Throwable $e) {
            return $this->handleException($e);
        }

        return new StreamedResponse(function () use ($response, $provider, $inputPrompt, $data): void {
            $all = '';

            while (!$response->eof()) {
                $part = $response->read(32);
                $all .= $part;

                echo $part;
                // ob_flush() raises a notice when no output buffer is active.
                if (ob_get_level() > 0) {
                    ob_flush();
                }
                flush();
            }

            $this->logStats($provider->getLastResponseStats(), $inputPrompt, $all, $data);
        });
    }

    /**
     * Returns the data the admin modal needs for one prompt.
     *
     * The response contains the prompt's title, its base user text (after preparePrompt()
     * processing), and the id/title list of all prompts offered for the given objectLabel.
     */
    #[AdminRoute('/llm/modal-data', methods: ['POST'])]
    public function modalData(Request $request): JsonResponse|Response
    {
        try {
            $data = $this->decodePayload($request->getContent());
            $prompt = $this->preparePrompt($data);
        } catch (\Throwable $e) {
            return $this->handleException($e);
        }

        $entity = $this->textPromptsUtil->getPromptEntityById($data['promptId']);

        $allPrompts = array_values(array_map(
            static fn ($promptEntity) => [
                'id' => $promptEntity->getId(),
                'title' => $promptEntity->getTitle(),
            ],
            $this->textObjectUtil->getObjectLabelPrompts($data['objectLabel']),
        ));

        return new JsonResponse([
            'title' => $entity->getTitle(),
            'prompt' => $prompt->getBaseUserText(),
            'prompts' => $allPrompts,
        ]);
    }

    /**
     * Builds the prompt input for a request payload.
     *
     * Required keys:
     * - objectLabel
     * - promptId
     *
     * Optional keys:
     * - text: replaces PromptPlaceholder::TEXT; defaults to ''.
     * - userPrompt: overrides the base user text when truthy.
     * - entityType, entityId: passed to the text object's modifyPrompt(); a falsy entityId becomes null.
     * - prevAnswer + editPrompt: when both are truthy, appended as an assistant answer followed by a user turn.
     *
     * @throws BadRequestHttpException when a required key is missing
     */
    protected function preparePrompt(array $data): TextPromptInput
    {
        $this->validateRequestBody($data, ['objectLabel', 'promptId']);

        $label = $data['objectLabel'];
        $promptId = $data['promptId'];
        $text = $data['text'] ?? '';
        $promptText = $data['userPrompt'] ?? null;
        $entityType = $data['entityType'] ?? null;
        // Normalise falsy ids (0, '', '0') to null; parentheses make the ??/?: precedence explicit.
        $entityId = ($data['entityId'] ?? null) ?: null;
        $prevAnswer = $data['prevAnswer'] ?? null;
        $editPrompt = $data['editPrompt'] ?? null;

        $prompt = $this->textPromptsUtil->getTextPromptInputById($promptId);

        if ($promptText) {
            $prompt->setBaseUserText($promptText);
        }

        $prompt->replacePlaceholder(PromptPlaceholder::TEXT, $text);

        $textObject = $this->textObjectUtil->getByLabel($label);
        $textObject->modifyPrompt($prompt, $entityType, $entityId);

        if ($prevAnswer && $editPrompt) {
            $prompt->addAssistantAnswer($prevAnswer)->addUserText($editPrompt);
        }

        return $prompt;
    }

    /**
     * Ensures every name in $requiredFields is a key of $data.
     *
     * Any non-array $data is treated as having no fields.
     *
     * @param list<string> $requiredFields
     *
     * @throws BadRequestHttpException listing the missing fields
     */
    protected function validateRequestBody(mixed $data, array $requiredFields): void
    {
        $diff = array_diff($requiredFields, array_keys(is_array($data) ? $data : []));
        if (!empty($diff)) {
            throw new BadRequestHttpException('missing required fields ['.implode(', ', $diff).']');
        }
    }

    /**
     * Converts an exception into a plain-text HTTP response.
     *
     * BadRequestHttpException keeps its message and status code. Anything else is reported to
     * Sentry and answered with a generic 500 (Czech: "Error while generating the response").
     */
    protected function handleException(\Throwable $e): Response
    {
        if ($e instanceof BadRequestHttpException) {
            return new Response($e->getMessage(), $e->getStatusCode());
        }

        \Sentry\captureException($e);

        return new Response('Chyba při generování odpovědi', 500);
    }

    /**
     * Decodes a JSON request payload into an array.
     *
     * @throws BadRequestHttpException when the payload is not valid JSON or not an object/array
     */
    private function decodePayload(string $json): array
    {
        try {
            $data = json_decode($json, true, flags: \JSON_THROW_ON_ERROR);
        } catch (\JsonException $e) {
            throw new BadRequestHttpException('invalid JSON payload', $e);
        }

        if (!is_array($data)) {
            throw new BadRequestHttpException('JSON payload must be an object');
        }

        return $data;
    }

    /**
     * Logs a finished prompt and records token counters. Does nothing when $stats is null.
     *
     * Writes a notice with the stats, the JSON-encoded input messages, the response and the
     * objectLabel. When a PrometheusWrapper is set, it also records input/output/total token
     * counters and flushes them.
     */
    private function logStats(?PromptResponseStats $stats, TextPromptInput $inputPrompt, string $response, mixed $data): void
    {
        if ($stats === null) {
            return;
        }

        $inputMessages = array_map(static fn ($message) => [
            'role' => $message->role->value,
            'content' => $message->content,
        ], $inputPrompt->getMessages());

        $this->logger->notice('LLM prompt: '.$inputPrompt->getBaseUserText(),
            array_merge(
                $stats->toArray(),
                [
                    // INVALID_UTF8_SUBSTITUTE keeps json_encode() from returning false on bad
                    // bytes; UNESCAPED_UNICODE keeps non-ASCII (e.g. Czech) text readable in logs.
                    'inputMessages' => json_encode($inputMessages, \JSON_UNESCAPED_UNICODE | \JSON_INVALID_UTF8_SUBSTITUTE),
                    'response' => $response,
                    'objectLabel' => $data['objectLabel'] ?? null,
                ],
            )
        );

        if ($this->prometheusWrapper) {
            $labels = $stats->getLabels();

            $this->prometheusWrapper->setCounter('kupshop', 'llm_prompts_tokens_input', 'Input tokens', $stats->getInputTokens(), $labels);
            $this->prometheusWrapper->setCounter('kupshop', 'llm_prompts_tokens_output', 'Output tokens', $stats->getOutputTokens(), $labels);
            $this->prometheusWrapper->setCounter('kupshop', 'llm_prompts_tokens_total', 'Total tokens', $stats->getTotalTokens(), $labels);
            $this->prometheusWrapper->flush();
        }
    }

    /**
     * Setter injection for the optional metrics wrapper (may receive null).
     */
    #[Required]
    public function setPrometheusWrapper(?PrometheusWrapper $prometheusWrapper): void
    {
        $this->prometheusWrapper = $prometheusWrapper;
    }
}
|