Add srt caller mode and stream encryption support. (#4088)

Add srt caller mode and stream encryption support.
1. Support srt caller mode, realize srt proxy pull stream proxy push
stream;
url parameter format such as: srt://127.0.0.1:9000?streamid=#!
::r=live/test11
2. Support srt stream encrypted transmission in caller and listener
mode.

---------

Co-authored-by: xiongguangjie <xiong_panda@163.com>
This commit is contained in:
baigao-X 2024-12-28 20:21:29 +08:00 committed by GitHub
parent cb4db80502
commit 1c8ed1c55a
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
27 changed files with 3002 additions and 17 deletions

View File

@ -463,6 +463,10 @@ if(ENABLE_VIDEOSTACK)
endif ()
endif ()
if(ENABLE_SRT)
update_cached_list(MK_COMPILE_DEFINITIONS ENABLE_SRT)
endif()
# ----------------------------------------------------------------------------
# Solution folders:
# ----------------------------------------------------------------------------

View File

@ -400,6 +400,8 @@ port=9000
latencyMul=4
#包缓存的大小
pktBufSize=8192
#srt udp服务器的密码,为空表示不加密
passPhrase=
[rtsp]

View File

@ -522,7 +522,7 @@
"response": []
},
{
"name": "添加rtsp/rtmp/hls拉流代理(addStreamProxy)",
"name": "添加rtsp/rtmp/hls/srt拉流代理(addStreamProxy)",
"request": {
"method": "GET",
"header": [],
@ -663,7 +663,19 @@
"value": null,
"description": "无人观看时,是否直接关闭(而不是通过on_none_reader hook返回close)",
"disabled": true
}
},
{
"key": "latency",
"value": null,
"description": "srt延时, 单位毫秒",
"disabled": true
},
{
"key": "passphrase",
"value": null,
"description": "srt拉流的密码",
"disabled": true
}
]
}
},
@ -753,7 +765,7 @@
"response": []
},
{
"name": "添加rtsp/rtmp推流(addStreamPusherProxy)",
"name": "添加rtsp/rtmp/srt推流(addStreamPusherProxy)",
"request": {
"method": "GET",
"header": [],
@ -815,7 +827,20 @@
"value": null,
"description": "推流重试次数,不传此参数或传值<=0时则无限重试",
"disabled": true
}
},
{
"key": "latency",
"value": null,
"description": "srt延时, 单位毫秒",
"disabled": true
},
{
"key": "passphrase",
"value": null,
"description": "srt推流的密码",
"disabled": true
}
]
}
},
@ -2610,4 +2635,4 @@
"value": "__defaultVhost__"
}
]
}
}

View File

@ -685,6 +685,7 @@ void addStreamPusherProxy(const string &schema,
int retry_count,
int rtp_type,
float timeout_sec,
const mINI &args,
const function<void(const SockException &ex, const string &key)> &cb) {
auto key = getPusherKey(schema, vhost, app, stream, url);
auto src = MediaSource::find(schema, vhost, app, stream);
@ -703,14 +704,20 @@ void addStreamPusherProxy(const string &schema,
// Add push stream proxy
auto pusher = s_pusher_proxy.make(key, src, retry_count);
// 先透传拷贝参数 [AUTO-TRANSLATED:22b5605e]
// First pass-through copy parameters
for (auto &pr : args) {
(*pusher)[pr.first] = pr.second;
}
// 指定RTP over TCP(播放rtsp时有效) [AUTO-TRANSLATED:1a062656]
// Specify RTP over TCP (effective when playing RTSP)
pusher->emplace(Client::kRtpType, rtp_type);
(*pusher)[Client::kRtpType] = rtp_type;
if (timeout_sec > 0.1f) {
// 推流握手超时时间 [AUTO-TRANSLATED:00762fc1]
// Push stream handshake timeout
pusher->emplace(Client::kTimeoutMS, timeout_sec * 1000);
(*pusher)[Client::kTimeoutMS] = timeout_sec * 1000;
}
// 开始推流,如果推流失败或者推流中止,将会自动重试若干次,默认一直重试 [AUTO-TRANSLATED:c8b95088]
@ -1174,6 +1181,12 @@ void installWebApi() {
api_regist("/index/api/addStreamPusherProxy", [](API_ARGS_MAP_ASYNC) {
CHECK_SECRET();
CHECK_ARGS("schema", "vhost", "app", "stream", "dst_url");
mINI args;
for (auto &pr : allArgs.args) {
args.emplace(pr.first, pr.second);
}
auto dst_url = allArgs["dst_url"];
auto retry_count = allArgs["retry_count"].empty() ? -1 : allArgs["retry_count"].as<int>();
addStreamPusherProxy(allArgs["schema"],
@ -1184,6 +1197,7 @@ void installWebApi() {
retry_count,
allArgs["rtp_type"],
allArgs["timeout_sec"],
args,
[invoker, val, headerOut, dst_url](const SockException &ex, const string &key) mutable {
if (ex) {
val["code"] = API::OtherFailed;

View File

@ -26,6 +26,14 @@ file(GLOB MediaKit_SRC_LIST
${CMAKE_CURRENT_SOURCE_DIR}/*/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/*/*.h)
if(NOT ENABLE_SRT)
file(GLOB SRT_SRC_LIST
${CMAKE_CURRENT_SOURCE_DIR}/Srt/*.c
${CMAKE_CURRENT_SOURCE_DIR}/Srt/*.cpp
${CMAKE_CURRENT_SOURCE_DIR}/Srt/*.h)
list(REMOVE_ITEM MediaKit_SRC_LIST ${SRT_SRC_LIST})
endif()
if(USE_SOLUTION_FOLDERS AND (NOT GROUP_BY_EXPLORER))
# IDE ,
set_file_group("${CMAKE_CURRENT_SOURCE_DIR}" ${MediaKit_SRC_LIST})
@ -49,6 +57,7 @@ target_link_libraries(zlmediakit
target_include_directories(zlmediakit
PRIVATE
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>"
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/..>"
PUBLIC
"$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}>")

View File

@ -396,6 +396,8 @@ const string kWaitTrackReady = "wait_track_ready";
const string kPlayTrack = "play_track";
const string kProxyUrl = "proxy_url";
const string kRtspSpeed = "rtsp_speed";
const string kLatency = "latency";
const string kPassPhrase = "passPhrase";
} // namespace Client
} // namespace mediakit

View File

@ -624,6 +624,10 @@ extern const std::string kProxyUrl;
// 设置开始rtsp倍速播放 [AUTO-TRANSLATED:5db03cad]
// Set the start RTSP playback speed
extern const std::string kRtspSpeed;
// Set SRT delay
extern const std::string kLatency;
// Set SRT PassPhrase
extern const std::string kPassPhrase;
} // namespace Client
} // namespace mediakit

View File

@ -15,6 +15,9 @@
#include "Rtmp/FlvPlayer.h"
#include "Http/HlsPlayer.h"
#include "Http/TsPlayerImp.h"
#ifdef ENABLE_SRT
#include "Srt/SrtPlayerImp.h"
#endif // ENABLE_SRT
using namespace std;
using namespace toolkit;
@ -70,6 +73,12 @@ PlayerBase::Ptr PlayerBase::createPlayer(const EventPoller::Ptr &in_poller, cons
}
}
#ifdef ENABLE_SRT
if (strcasecmp("srt", prefix.data()) == 0) {
return PlayerBase::Ptr(new SrtPlayerImp(poller), release_func);
}
#endif//ENABLE_SRT
throw std::invalid_argument("not supported play schema:" + url_in);
}
@ -78,6 +87,8 @@ PlayerBase::PlayerBase() {
this->mINI::operator[](Client::kMediaTimeoutMS) = 5000;
this->mINI::operator[](Client::kBeatIntervalMS) = 5000;
this->mINI::operator[](Client::kWaitTrackReady) = true;
this->mINI::operator[](Client::kLatency) = 0;
this->mINI::operator[](Client::kPassPhrase) = "";
}
} /* namespace mediakit */

View File

@ -12,6 +12,9 @@
#include "PusherBase.h"
#include "Rtsp/RtspPusher.h"
#include "Rtmp/RtmpPusher.h"
#ifdef ENABLE_SRT
#include "Srt/SrtPusher.h"
#endif // ENABLE_SRT
using namespace toolkit;
@ -50,6 +53,13 @@ PusherBase::Ptr PusherBase::createPusher(const EventPoller::Ptr &in_poller,
return PusherBase::Ptr(new RtmpPusherImp(poller, std::dynamic_pointer_cast<RtmpMediaSource>(src)), release_func);
}
#ifdef ENABLE_SRT
if (strcasecmp("srt", prefix.data()) == 0) {
return PusherBase::Ptr(new SrtPusherImp(poller, std::dynamic_pointer_cast<TSMediaSource>(src)), release_func);
}
#endif//ENABLE_SRT
throw std::invalid_argument("not supported push schema:" + url);
}

1047
src/Srt/SrtCaller.cpp Normal file

File diff suppressed because it is too large Load Diff

199
src/Srt/SrtCaller.h Normal file
View File

@ -0,0 +1,199 @@
/*
* Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved.
*
* This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit).
*
* Use of this source code is governed by MIT-like license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLMEDIAKIT_SRTCALLER_H
#define ZLMEDIAKIT_SRTCALLER_H
//srt
#include "srt/Packet.hpp"
#include "srt/Crypto.hpp"
#include "srt/PacketQueue.hpp"
#include "srt/PacketSendQueue.hpp"
#include "srt/Statistic.hpp"
#include "Poller/EventPoller.h"
#include "Network/Socket.h"
#include "Poller/Timer.h"
#include "Util/TimeTicker.h"
#include "Common/MultiMediaSourceMuxer.h"
#include "Rtp/Decoder.h"
#include "TS/TSMediaSource.h"
#include <memory>
#include <string>
namespace mediakit {
// 解析srt 信令url的工具类
class SrtUrl {
public:
std::string _full_url;
std::string _params;
std::string _host;
uint16_t _port;
std::string _vhost;
std::string _app;
std::string _stream;
public:
void parse(const std::string &url);
};
// 实现了webrtc代理拉流功能
class SrtCaller : public std::enable_shared_from_this<SrtCaller>{
public:
using Ptr = std::shared_ptr<SrtCaller>;
using SteadyClock = std::chrono::steady_clock;
using TimePoint = std::chrono::time_point<SteadyClock>;
SrtCaller(const toolkit::EventPoller::Ptr &poller);
virtual ~SrtCaller();
const toolkit::EventPoller::Ptr &getPoller() const {return _poller;}
virtual void inputSockData(uint8_t *buf, int len, struct sockaddr *addr);
virtual void onSendTSData(const SRT::Buffer::Ptr &buffer, bool flush);
protected:
virtual void onConnect();
virtual void onHandShakeFinished();
virtual void onResult(const toolkit::SockException &ex);
virtual void onSRTData(SRT::DataPacket::Ptr pkt);
virtual uint16_t getLatency() = 0;
virtual int getLatencyMul();
virtual int getPktBufSize();
virtual float getTimeOutSec();
virtual bool isPlayer() = 0;
private:
void doHandshake();
void sendHandshakeInduction();
void sendHandshakeConclusion();
void sendACKPacket();
void sendLightACKPacket();
void sendNAKPacket(std::list<SRT::PacketQueue::LostPair> &lost_list);
void sendMsgDropReq(uint32_t first, uint32_t last);
void sendKeepLivePacket();
void sendShutDown();
void tryAnnounceKeyMaterial();
void sendControlPacket(SRT::ControlPacket::Ptr pkt, bool flush = true);
void sendDataPacket(SRT::DataPacket::Ptr pkt, char *buf, int len, bool flush = false);
void sendPacket(toolkit::Buffer::Ptr pkt, bool flush);
void handleHandshake(uint8_t *buf, int len, struct sockaddr *addr);
void handleHandshakeInduction(SRT::HandshakePacket &pkt, struct sockaddr *addr);
void handleHandshakeConclusion(SRT::HandshakePacket &pkt, struct sockaddr *addr);
void handleACK(uint8_t *buf, int len, struct sockaddr *addr);
void handleACKACK(uint8_t *buf, int len, struct sockaddr *addr);
void handleNAK(uint8_t *buf, int len, struct sockaddr *addr);
void handleDropReq(uint8_t *buf, int len, struct sockaddr *addr);
void handleKeeplive(uint8_t *buf, int len, struct sockaddr *addr);
void handleShutDown(uint8_t *buf, int len, struct sockaddr *addr);
void handlePeerError(uint8_t *buf, int len, struct sockaddr *addr);
void handleCongestionWarning(uint8_t *buf, int len, struct sockaddr *addr);
void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr *addr);
void handleDataPacket(uint8_t *buf, int len, struct sockaddr *addr);
void handleKeyMaterialReqPacket(uint8_t *buf, int len, struct sockaddr *addr);
void handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockaddr *addr);
void checkAndSendAckNak();
void createTimerForCheckAlive();
std::string generateStreamId();
uint32_t generateSocketId();
int32_t generateInitSeq();
size_t getPayloadSize();
virtual std::string getPassphrase() = 0;
protected:
SrtUrl _url;
toolkit::EventPoller::Ptr _poller;
bool _is_handleshake_finished = false;
private:
toolkit::Socket::Ptr _socket;
TimePoint _now;
TimePoint _start_timestamp;
// for calculate rtt for delay
TimePoint _induction_ts;
//the initial value of RTT is 100 milliseconds
//RTTVar is 50 milliseconds
uint32_t _rtt = 100 * 1000;
uint32_t _rtt_variance = 50 * 1000;
//local
uint32_t _socket_id = 0;
uint32_t _init_seq_number = 0;
uint32_t _mtu = 1500;
uint32_t _max_flow_window_size = 8192;
uint16_t _delay = 120;
//peer
uint32_t _sync_cookie = 0;
uint32_t _peer_socket_id;
// for handshake
SRT::Timer::Ptr _handleshake_timer;
SRT::HandshakePacket::Ptr _handleshake_req;
// for keeplive
SRT::Ticker _send_ticker;
SRT::Timer::Ptr _keeplive_timer;
// for alive
SRT::Ticker _alive_ticker;
SRT::Timer::Ptr _alive_timer;
// for recv
SRT::PacketQueueInterface::Ptr _recv_buf;
uint32_t _last_pkt_seq = 0;
// Ack
SRT::UTicker _ack_ticker;
uint32_t _last_ack_pkt_seq = 0;
uint32_t _light_ack_pkt_count = 0;
uint32_t _ack_number_count = 0;
std::map<uint32_t, TimePoint> _ack_send_timestamp;
// Full Ack
// Link Capacity and Receiving Rate Estimation
std::shared_ptr<SRT::PacketRecvRateContext> _pkt_recv_rate_context;
std::shared_ptr<SRT::EstimatedLinkCapacityContext> _estimated_link_capacity_context;
// Nak
SRT::UTicker _nak_ticker;
//for Send
SRT::PacketSendQueue::Ptr _send_buf;
SRT::ResourcePool<SRT::BufferRaw> _packet_pool;
uint32_t _send_packet_seq_number = 0;
uint32_t _send_msg_number = 1;
//AckAck
uint32_t _last_recv_ackack_seq_num = 0;
// for encryption
SRT::Crypto::Ptr _crypto;
SRT::Timer::Ptr _announce_timer;
SRT::KeyMaterialPacket::Ptr _announce_req;
};
} /* namespace mediakit */
#endif /* ZLMEDIAKIT_SRTCALLER_H */

169
src/Srt/SrtPlayer.cpp Normal file
View File

@ -0,0 +1,169 @@
/*
* Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved.
*
* This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit).
*
* Use of this source code is governed by MIT-like license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "SrtPlayer.h"
#include "SrtPlayerImp.h"
#include "Common/config.h"
#include "Http/HlsPlayer.h"
using namespace toolkit;
using namespace std;
namespace mediakit {
SrtPlayer::SrtPlayer(const EventPoller::Ptr &poller)
: SrtCaller(poller) {
DebugL;
}
SrtPlayer::~SrtPlayer(void) {
DebugL;
}
void SrtPlayer::play(const string &strUrl) {
DebugL;
try {
_url.parse(strUrl);
} catch (std::exception &ex) {
onResult(SockException(Err_other, StrPrinter << "illegal srt url:" << ex.what()));
return;
}
onConnect();
return;
}
void SrtPlayer::teardown() {
SrtCaller::onResult(SockException(Err_other, StrPrinter << "teardown: " << _url._full_url));
}
void SrtPlayer::pause(bool bPause) {
DebugL;
}
void SrtPlayer::speed(float speed) {
DebugL;
}
void SrtPlayer::onHandShakeFinished() {
SrtCaller::onHandShakeFinished();
onResult(SockException(Err_success, "srt play success"));
}
void SrtPlayer::onResult(const SockException &ex) {
SrtCaller::onResult(ex);
if (!ex) {
// 播放成功
onPlayResult(ex);
_benchmark_mode = (*this)[Client::kBenchmarkMode].as<int>();
// 播放成功,恢复数据包接收超时定时器
_recv_ticker.resetTime();
auto timeout = getTimeOutSec();
//读取配置文件
weak_ptr<SrtPlayer> weakSelf = static_pointer_cast<SrtPlayer>(shared_from_this());
// 创建rtp数据接收超时检测定时器
_check_timer = std::make_shared<Timer>(timeout /2,
[weakSelf, timeout]() {
auto strongSelf = weakSelf.lock();
if (!strongSelf) {
return false;
}
if (strongSelf->_recv_ticker.elapsedTime() > timeout * 1000) {
// 接收媒体数据包超时
strongSelf->onResult(SockException(Err_timeout, "receive srt media data timeout:" + strongSelf->_url._full_url));
return false;
}
return true;
}, getPoller());
} else {
WarnL << ex.getErrCode() << " " << ex.what();
if (ex.getErrCode() == Err_shutdown) {
// 主动shutdown的不触发回调
return;
}
if (!_is_handleshake_finished) {
onPlayResult(ex);
} else {
onShutdown(ex);
}
}
return;
}
void SrtPlayer::onSRTData(SRT::DataPacket::Ptr pkt) {
_recv_ticker.resetTime();
}
uint16_t SrtPlayer::getLatency() {
auto latency = (*this)[Client::kLatency].as<uint16_t>();
return (uint16_t)latency ;
}
float SrtPlayer::getTimeOutSec() {
auto timeoutMS = (*this)[Client::kTimeoutMS].as<uint64_t>();
return (float)timeoutMS / (float)1000;
}
std::string SrtPlayer::getPassphrase() {
auto passPhrase = (*this)[Client::kPassPhrase].as<string>();
return passPhrase;
}
///////////////////////////////////////////////////
// SrtPlayerImp
void SrtPlayerImp::onPlayResult(const toolkit::SockException &ex) {
if (ex) {
Super::onPlayResult(ex);
}
//success result only occur when addTrackCompleted
return;
}
std::vector<Track::Ptr> SrtPlayerImp::getTracks(bool ready /*= true*/) const {
return _demuxer ? static_pointer_cast<HlsDemuxer>(_demuxer)->getTracks(ready) : Super::getTracks(ready);
}
void SrtPlayerImp::addTrackCompleted() {
Super::onPlayResult(toolkit::SockException(toolkit::Err_success, "play success"));
}
void SrtPlayerImp::onSRTData(SRT::DataPacket::Ptr pkt) {
SrtPlayer::onSRTData(pkt);
if (_benchmark_mode) {
return;
}
auto strong_self = shared_from_this();
if (!_demuxer) {
auto demuxer = std::make_shared<HlsDemuxer>();
demuxer->start(getPoller(), this);
_demuxer = std::move(demuxer);
}
if (!_decoder && _demuxer) {
_decoder = DecoderImp::createDecoder(DecoderImp::decoder_ts, _demuxer.get());
}
if (_decoder && _demuxer) {
_decoder->input(reinterpret_cast<const uint8_t *>(pkt->payloadData()), pkt->payloadSize());
}
return;
}
} /* namespace mediakit */

65
src/Srt/SrtPlayer.h Normal file
View File

@ -0,0 +1,65 @@
/*
* Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved.
*
* This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit).
*
* Use of this source code is governed by MIT-like license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLMEDIAKIT_SRTPLAYER_H
#define ZLMEDIAKIT_SRTPLAYER_H
#include "Network/Socket.h"
#include "Player/PlayerBase.h"
#include "Poller/Timer.h"
#include "Util/TimeTicker.h"
#include "srt/SrtTransport.hpp"
#include "Http/HttpRequester.h"
#include <memory>
#include <string>
#include "SrtCaller.h"
namespace mediakit {
// 实现了srt代理拉流功能
class SrtPlayer
: public PlayerBase , public SrtCaller {
public:
using Ptr = std::shared_ptr<SrtPlayer>;
SrtPlayer(const toolkit::EventPoller::Ptr &poller);
~SrtPlayer() override;
//// PlayerBase override////
void play(const std::string &strUrl) override;
void teardown() override;
void pause(bool pause) override;
void speed(float speed) override;
protected:
//// SrtCaller override////
void onHandShakeFinished() override;
void onSRTData(SRT::DataPacket::Ptr pkt) override;
void onResult(const toolkit::SockException &ex) override;
bool isPlayer() override {return true;}
uint16_t getLatency() override;
float getTimeOutSec() override;
std::string getPassphrase() override;
protected:
//是否为性能测试模式
bool _benchmark_mode = false;
//超时功能实现
toolkit::Ticker _recv_ticker;
std::shared_ptr<toolkit::Timer> _check_timer;
};
} /* namespace mediakit */
#endif /* ZLMEDIAKIT_SRTPLAYER_H */

51
src/Srt/SrtPlayerImp.h Normal file
View File

@ -0,0 +1,51 @@
/*
* Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved.
*
* This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit).
*
* Use of this source code is governed by MIT-like license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLMEDIAKIT_SRtPLAYERIMP_H
#define ZLMEDIAKIT_SRtPLAYERIMP_H
#include "SrtPlayer.h"
namespace mediakit {
class SrtPlayerImp
: public PlayerImp<SrtPlayer, PlayerBase>
, private TrackListener {
public:
using Ptr = std::shared_ptr<SrtPlayerImp>;
using Super = PlayerImp<SrtPlayer, PlayerBase>;
SrtPlayerImp(const toolkit::EventPoller::Ptr &poller) : Super(poller) {}
~SrtPlayerImp() override { DebugL; }
private:
//// SrtPlayer override////
void onSRTData(SRT::DataPacket::Ptr pkt) override;
//// PlayerBase override////
void onPlayResult(const toolkit::SockException &ex) override;
std::vector<Track::Ptr> getTracks(bool ready = true) const override;
private:
//// TrackListener override////
bool addTrack(const Track::Ptr &track) override { return true; }
void addTrackCompleted() override;
private:
// for player
DecoderImp::Ptr _decoder;
MediaSinkInterface::Ptr _demuxer;
// for pusher
TSMediaSource::RingType::RingReader::Ptr _ts_reader;
};
} /* namespace mediakit */
#endif /* ZLMEDIAKIT_SRtPLAYERIMP_H */

116
src/Srt/SrtPusher.cpp Normal file
View File

@ -0,0 +1,116 @@
/*
* Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved.
*
* This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit).
*
* Use of this source code is governed by MIT-like license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#include "SrtPusher.h"
#include "Common/config.h"
using namespace toolkit;
using namespace std;
namespace mediakit {
SrtPusher::SrtPusher(const EventPoller::Ptr &poller, const TSMediaSource::Ptr &src) : SrtCaller(poller) {
_push_src = src;
DebugL;
}
SrtPusher::~SrtPusher(void) {
DebugL;
}
void SrtPusher::publish(const string &strUrl) {
DebugL;
try {
_url.parse(strUrl);
} catch (std::exception &ex) {
onResult(SockException(Err_other, StrPrinter << "illegal srt url:" << ex.what()));
return;
}
onConnect();
return;
}
void SrtPusher::teardown() {
SrtCaller::onResult(SockException(Err_other, StrPrinter << "teardown: " << _url._full_url));
}
void SrtPusher::onHandShakeFinished() {
SrtCaller::onHandShakeFinished();
onResult(SockException(Err_success, "srt push success"));
doPublish();
}
void SrtPusher::onResult(const SockException &ex) {
SrtCaller::onResult(ex);
if (!ex) {
onPublishResult(ex);
} else {
WarnL << ex.getErrCode() << " " << ex.what();
if (ex.getErrCode() == Err_shutdown) {
// 主动shutdown的不触发回调
return;
}
if (!_is_handleshake_finished) {
onPublishResult(ex);
} else {
onShutdown(ex);
}
}
return;
}
uint16_t SrtPusher::getLatency() {
auto latency = (*this)[Client::kLatency].as<uint16_t>();
return (uint16_t)latency ;
}
float SrtPusher::getTimeOutSec() {
auto timeoutMS = (*this)[Client::kTimeoutMS].as<uint64_t>();
return (float)timeoutMS / (float)1000;
}
std::string SrtPusher::getPassphrase() {
auto passPhrase = (*this)[Client::kPassPhrase].as<string>();
return passPhrase;
}
void SrtPusher::doPublish() {
auto src = _push_src.lock();
if (!src) {
onResult(SockException(Err_eof, "the media source was released"));
return;
}
// 异步查找直播流
std::weak_ptr<SrtPusher> weak_self = static_pointer_cast<SrtPusher>(shared_from_this());
_ts_reader = src->getRing()->attach(getPoller());
_ts_reader->setDetachCB([weak_self]() {
auto strong_self = weak_self.lock();
if (!strong_self) {
// 本对象已经销毁
return;
}
strong_self->onShutdown(SockException(Err_shutdown));
});
_ts_reader->setReadCB([weak_self](const TSMediaSource::RingDataType &ts_list) {
auto strong_self = weak_self.lock();
if (!strong_self) {
// 本对象已经销毁
return;
}
size_t i = 0;
auto size = ts_list->size();
ts_list->for_each([&](const TSPacket::Ptr &ts) {
strong_self->onSendTSData(ts, ++i == size);
});
});
}
} /* namespace mediakit */

59
src/Srt/SrtPusher.h Normal file
View File

@ -0,0 +1,59 @@
/*
* Copyright (c) 2016-present The ZLMediaKit project authors. All Rights Reserved.
*
* This file is part of ZLMediaKit(https://github.com/ZLMediaKit/ZLMediaKit).
*
* Use of this source code is governed by MIT-like license that can be found in the
* LICENSE file in the root of the source tree. All contributing project authors
* may be found in the AUTHORS file in the root of the source tree.
*/
#ifndef ZLMEDIAKIT_SRTPUSHER_H
#define ZLMEDIAKIT_SRTPUSHER_H
#include "Network/Socket.h"
#include "Pusher/PusherBase.h"
#include "Poller/Timer.h"
#include "Util/TimeTicker.h"
#include "srt/SrtTransport.hpp"
#include "Http/HttpRequester.h"
#include <memory>
#include <string>
#include "SrtCaller.h"
namespace mediakit {
// 实现了srt代理推流功能
class SrtPusher
: public PusherBase , public SrtCaller {
public:
using Ptr = std::shared_ptr<SrtPusher>;
SrtPusher(const toolkit::EventPoller::Ptr &poller,const TSMediaSource::Ptr &src);
~SrtPusher() override;
//// PusherBase override////
void publish(const std::string &url) override;
void teardown() override;
void doPublish();
protected:
//// SrtCaller override////
void onHandShakeFinished() override;
void onResult(const toolkit::SockException &ex) override;
bool isPlayer() override {return false;}
uint16_t getLatency() override;
float getTimeOutSec() override;
std::string getPassphrase() override;
protected:
std::weak_ptr<TSMediaSource> _push_src;
TSMediaSource::RingType::RingReader::Ptr _ts_reader;
};
using SrtPusherImp = PusherImp<SrtPusher, PusherBase>;
} /* namespace mediakit */
#endif /* ZLMEDIAKIT_SRTPUSHER_H */

View File

@ -28,6 +28,10 @@ static inline int64_t DurationCountMicroseconds(SteadyClock::duration dur) {
return std::chrono::duration_cast<std::chrono::microseconds>(dur).count();
}
static inline uint32_t DurationCountSeconds(SteadyClock::duration dur) {
return std::chrono::duration_cast<std::chrono::seconds>(dur).count();
}
static inline uint32_t loadUint32(uint8_t *ptr) {
return ptr[0] << 24 | ptr[1] << 16 | ptr[2] << 8 | ptr[3];
}
@ -113,4 +117,4 @@ private:
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_COMMON_H
#endif // ZLMEDIAKIT_SRT_COMMON_H

507
srt/Crypto.cpp Normal file
View File

@ -0,0 +1,507 @@
#include <atomic>
#include "Util/MD5.h"
#include "Util/logger.h"
#include "Crypto.hpp"
#if defined(ENABLE_OPENSSL)
#include "openssl/evp.h"
#endif
using namespace toolkit;
using namespace std;
using namespace SRT;
namespace SRT {
#if defined(ENABLE_OPENSSL)
inline const EVP_CIPHER* aes_key_len_mapping_wrap_cipher(int key_len) {
switch (key_len) {
case 192/8: return EVP_aes_192_wrap();
case 256/8: return EVP_aes_256_wrap();
case 128/8:
default:
return EVP_aes_128_wrap();
}
}
inline const EVP_CIPHER* aes_key_len_mapping_ctr_cipher(int key_len) {
switch (key_len) {
case 192/8: return EVP_aes_192_ctr();
case 256/8: return EVP_aes_256_ctr();
case 128/8:
default:
return EVP_aes_128_ctr();
}
}
#endif
/**
* @brief: aes_wrap
* @param [in]: in warp的数据
* @param [in]: in_len warp的数据长度
* @param [out]: out warp后输出的数据
* @param [out]: outLen
* @param [in]: key
* @param [in]: key_len
* @return : true: false:
**/
static bool aes_wrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len) {
#if defined(ENABLE_OPENSSL)
EVP_CIPHER_CTX* ctx = NULL;
*outLen = 0;
do {
if (!(ctx = EVP_CIPHER_CTX_new())) {
WarnL << "EVP_CIPHER_CTX_new fail";
break;
}
EVP_CIPHER_CTX_set_flags(ctx, EVP_CIPHER_CTX_FLAG_WRAP_ALLOW);
if (1 != EVP_EncryptInit_ex(ctx, aes_key_len_mapping_wrap_cipher(key_len), NULL, key, NULL)) {
WarnL << "EVP_EncryptInit_ex fail";
break;
}
int len1 = 0;
if (1 != EVP_EncryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) {
WarnL << "EVP_EncryptUpdate fail";
break;
}
int len2 = 0;
if (1 != EVP_EncryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) {
WarnL << "EVP_EncryptFinal_ex fail";
break;
}
*outLen = len1 + len2;
} while (0);
if (ctx != NULL) {
EVP_CIPHER_CTX_free(ctx);
}
return *outLen != 0;
#else
return false;
#endif
}
/**
* @brief: aes_unwrap
* @param [in]: in unwrap的数据
* @param [in]: in_len unwrap的数据长度
* @param [out]: out unwrap后输出的数据
* @param [out]: outLen unwrap后输出的数据长度
* @param [in]: key
* @param [in]: key_len
* @return : true: false:
**/
static bool aes_unwrap(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len) {
#if defined(ENABLE_OPENSSL)
EVP_CIPHER_CTX* ctx = NULL;
*outLen = 0;
do {
if (!(ctx = EVP_CIPHER_CTX_new())) {
WarnL << "EVP_CIPHER_CTX_new fail";
break;
}
EVP_CIPHER_CTX_set_flags(ctx, EVP_CIPHER_CTX_FLAG_WRAP_ALLOW);
if (1 != EVP_DecryptInit_ex(ctx, aes_key_len_mapping_wrap_cipher(key_len), NULL, key, NULL)) {
WarnL << "EVP_DecryptInit_ex fail";
break;
}
//设置pkcs7padding
if (1 != EVP_CIPHER_CTX_set_padding(ctx, 1)) {
WarnL << "EVP_CIPHER_CTX_set_padding fail";
break;
}
int len1 = 0;
if (1 != EVP_DecryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) {
WarnL << "EVP_DecryptUpdate fail";
break;
}
int len2 = 0;
if (1 != EVP_DecryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) {
WarnL << "EVP_DecryptFinal_ex fail";
break;
}
*outLen = len1 + len2;
} while (0);
if (ctx != NULL) {
EVP_CIPHER_CTX_free(ctx);
}
return *outLen != 0;
#else
return false;
#endif
}
/**
* @brief: aes ctr
* @param [in]: in
* @param [in]: in_len
* @param [out]: out
* @param [out]: outLen
* @param [in]: key
* @param [in]: key_len
* @param [in]: iv iv向量(16byte)
* @return : true: false:
**/
static bool aes_ctr_encrypt(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len, uint8_t* iv) {
#if defined(ENABLE_OPENSSL)
EVP_CIPHER_CTX* ctx = NULL;
*outLen = 0;
do {
if (!(ctx = EVP_CIPHER_CTX_new())) {
WarnL << "EVP_CIPHER_CTX_new fail";
break;
}
if (1 != EVP_EncryptInit_ex(ctx, aes_key_len_mapping_ctr_cipher(key_len), NULL, key, iv)) {
WarnL << "EVP_EncryptInit_ex fail";
break;
}
int len1 = 0;
if (1 != EVP_EncryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) {
WarnL << "EVP_EncryptUpdate fail";
break;
}
int len2 = 0;
if (1 != EVP_EncryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) {
WarnL << "EVP_EncryptFinal_ex fail";
break;
}
*outLen = len1 + len2;
} while (0);
if (ctx != NULL) {
EVP_CIPHER_CTX_free(ctx);
}
return *outLen != 0;
#else
return false;
#endif
}
/**
* @brief: aes ctr
* @param [in]: in
* @param [in]: in_len
* @param [out]: out
* @param [out]: outLen
* @param [in]: key
* @param [in]: key_len
* @param [in]: iv iv向量(16byte)
* @return : true: false:
**/
static bool aes_ctr_decrypt(const uint8_t* in, int in_len, uint8_t* out, int* outLen, uint8_t* key, int key_len, uint8_t* iv) {
#if defined(ENABLE_OPENSSL)
EVP_CIPHER_CTX* ctx = NULL;
*outLen = 0;
do {
if (!(ctx = EVP_CIPHER_CTX_new())) {
WarnL << "EVP_CIPHER_CTX_new fail";
break;
}
if (1 != EVP_DecryptInit_ex(ctx, aes_key_len_mapping_ctr_cipher(key_len), NULL, key, iv)) {
WarnL << "EVP_DecryptInit_ex fail";
break;
}
int len1 = 0;
if (1 != EVP_DecryptUpdate(ctx, (uint8_t*)out, &len1, (uint8_t*)in, in_len)) {
WarnL << "EVP_DecryptUpdate fail";
break;
}
int len2 = 0;
if (1 != EVP_DecryptFinal_ex(ctx, (uint8_t*)out + len1, &len2)) {
WarnL << "EVP_DecryptFinal_ex fail";
break;
}
*outLen = len1 + len2;
} while (0);
if (ctx != NULL) {
EVP_CIPHER_CTX_free(ctx);
}
return *outLen != 0;
#else
return false;
#endif
}
///////////////////////////////////////////////////
// CryptoContext
CryptoContext::CryptoContext(const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet) :
_passparase(passparase), _kk(kk) {
if (packet) {
loadFromKeyMaterial(packet);
} else {
refresh();
}
}
void CryptoContext::refresh() {
if (_salt.empty()) {
_salt = makeRandStr(_slen, false);
generateKEK();
}
_sek = makeRandStr(_klen, false);
return;
}
std::string CryptoContext::generateWarppedKey() {
string warpped_key;
int size = (_sek.size() + 15) /16 * 16 + 8;
warpped_key.resize(size);
auto res = aes_wrap((uint8_t*)_sek.data(), _sek.size(), (uint8_t*)warpped_key.data(), &size, (uint8_t*)_kek.data(), _kek.size());
if (!res) {
return "";
}
warpped_key.resize(size);
return warpped_key;
}
void CryptoContext::loadFromKeyMaterial(KeyMaterial::Ptr packet) {
_slen = packet->_slen;
_klen = packet->_klen;
_salt = packet->_salt;
generateKEK();
auto warpped_key = packet->_warpped_key;
BufferLikeString sek;
int size = warpped_key.size();
sek.resize(size);
auto ret = aes_unwrap((uint8_t*)warpped_key.data(), warpped_key.size(), (uint8_t*)sek.data(), &size, (uint8_t*)_kek.data(), _kek.size());
if (!ret) {
throw std::runtime_error(StrPrinter <<"warpped_key unwrap fail, password may mismatch");
}
sek.resize(size);
if (packet->_kk == KeyMaterial::KEY_BASED_ENCRYPTION_BOTH_SEK) {
if (_kk == KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK) {
_sek = sek.substr(0, _slen);
} else {
_sek = sek.substr(_slen, _slen);
}
} else {
_sek = sek;
}
return;
}
bool CryptoContext::generateKEK() {
/**
SEK = PRNG(KLen)
Salt = PRNG(128)
KEK = PBKDF2(passphrase, LSB(64,Salt), Iter, KLen)
**/
_kek.resize(_klen);
#if defined(ENABLE_OPENSSL)
if (PKCS5_PBKDF2_HMAC(_passparase.data(), _passparase.length(), (uint8_t*)_salt.data() + _slen - 64/8, 64 /8, _iter, EVP_sha1(), _klen, (uint8_t*)_kek.data()) != 1) {
return false;
}
return true;
#else
return false;
#endif
}
BufferLikeString::Ptr CryptoContext::generateIv(uint32_t pkt_seq_no) {
auto iv = std::make_shared<BufferLikeString>();
iv->resize(128 /8);
uint8_t* saltData = (uint8_t*)_salt.data();
uint8_t* ivData = (uint8_t*)iv->data();
memset((void*)ivData, 0, iv->size());
memcpy((void*)(ivData + 10), (void*)&pkt_seq_no, 4);
for (size_t i = 0; i < std::min<size_t>(_salt.size(), (size_t)112 /8); ++i) {
ivData[i] ^= saltData[i];
}
return iv;
}
///////////////////////////////////////////////////
// AesCtrCryptoContext
AesCtrCryptoContext::AesCtrCryptoContext(const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet) :
CryptoContext(passparase, kk, packet) {
}
BufferLikeString::Ptr AesCtrCryptoContext::encrypt(uint32_t pkt_seq_no, const char *buf, int len) {
auto iv = generateIv(htonl(pkt_seq_no));
auto payload = std::make_shared<BufferLikeString>();
int size = (len + 15) /16 * 16 + 8;
payload->resize(size);
auto ret = aes_ctr_encrypt((const uint8_t*)buf, len, (uint8_t*)payload->data(), &size, (uint8_t*)_sek.data(), _sek.size(), (uint8_t*)iv->data());
if (!ret) {
return nullptr;
}
payload->resize(size);
return payload;
}
BufferLikeString::Ptr AesCtrCryptoContext::decrypt(uint32_t pkt_seq_no, const char *buf, int len) {
auto iv = generateIv(htonl(pkt_seq_no));
auto payload = std::make_shared<BufferLikeString>();
int size = len;
payload->resize(size);
auto ret = aes_ctr_decrypt((const uint8_t*)buf, len, (uint8_t*)payload->data(), &size, (uint8_t*)_sek.data(), _sek.size(), (uint8_t*)iv->data());
if (!ret) {
return nullptr;
}
payload->resize(size);
return payload;
}
///////////////////////////////////////////////////
// Crypto
Crypto::Crypto(const std::string& passparase) :
_passparase(passparase) {
#ifndef ENABLE_OPENSSL
throw std::invalid_argument("openssl disable, please set ENABLE_OPENSSL when compile");
#endif
_ctx_pair[0] = createCtx(KeyMaterial::CIPHER_AES_CTR, _passparase, KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK);
_ctx_pair[1] = createCtx(KeyMaterial::CIPHER_AES_CTR, _passparase, KeyMaterial::KEY_BASED_ENCRYPTION_ODD_SEK);
_ctx_idx = 0;
}
CryptoContext::Ptr Crypto::createCtx(int cipher, const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet) {
switch (cipher){
case KeyMaterial::CIPHER_AES_CTR:
return std::make_shared<AesCtrCryptoContext>(passparase, kk, packet);
case KeyMaterial::CIPHER_AES_ECB:
case KeyMaterial::CIPHER_AES_CBC:
case KeyMaterial::CIPHER_AES_GCM:
default:
throw std::runtime_error(StrPrinter <<"not support cipher " << cipher);
}
}
HSExtKeyMaterial::Ptr Crypto::generateKeyMaterialExt(uint16_t extension_type) {
HSExtKeyMaterial::Ptr ext = std::make_shared<HSExtKeyMaterial>();
ext->extension_type = extension_type;
ext->_kk = _ctx_pair[_ctx_idx]->_kk;
ext->_cipher = _ctx_pair[_ctx_idx]->getCipher();
ext->_slen = _ctx_pair[_ctx_idx]->_slen;
ext->_klen = _ctx_pair[_ctx_idx]->_klen;
ext->_salt = _ctx_pair[_ctx_idx]->_salt;
ext->_warpped_key = _ctx_pair[_ctx_idx]->generateWarppedKey();
return ext;
}
KeyMaterialPacket::Ptr Crypto::generateAnnouncePacket(CryptoContext::Ptr ctx) {
KeyMaterialPacket::Ptr pkt = std::make_shared<KeyMaterialPacket>();
pkt->sub_type = HSExt::SRT_CMD_KMREQ;
pkt->_kk = ctx->_kk;
pkt->_cipher = ctx->getCipher();
pkt->_slen = ctx->_slen;
pkt->_klen = ctx->_klen;
pkt->_salt = ctx->_salt;
pkt->_warpped_key = ctx->generateWarppedKey();
return pkt;
}
KeyMaterialPacket::Ptr Crypto::takeAwayAnnouncePacket() {
auto pkt = _re_announce_pkt;
_re_announce_pkt = nullptr;
return pkt;
}
bool Crypto::loadFromKeyMaterial(KeyMaterial::Ptr packet) {
try {
if (packet->_kk == KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK) {
_ctx_pair[0] = createCtx(packet->_cipher, _passparase, packet->_kk, packet);
} else if (packet->_kk == KeyMaterial::KEY_BASED_ENCRYPTION_ODD_SEK) {
_ctx_pair[1] = createCtx(packet->_cipher, _passparase, packet->_kk, packet);
} else if (packet->_kk == KeyMaterial::KEY_BASED_ENCRYPTION_BOTH_SEK) {
_ctx_pair[0] = createCtx(packet->_cipher, _passparase, KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK, packet);
_ctx_pair[1] = createCtx(packet->_cipher, _passparase, KeyMaterial::KEY_BASED_ENCRYPTION_ODD_SEK, packet);
}
} catch (std::exception &ex) {
WarnL << ex.what();
return false;
}
return true;
}
BufferLikeString::Ptr Crypto::encrypt(DataPacket::Ptr pkt, const char *buf, int len) {
_pkt_count++;
//refresh
if (_pkt_count == _re_announcement_period) {
auto ctx = createCtx(KeyMaterial::CIPHER_AES_CTR, _passparase, _ctx_pair[!_ctx_idx]->_kk);
_ctx_pair[!_ctx_idx] = ctx;
_re_announce_pkt = generateAnnouncePacket(ctx);
}
if (_pkt_count > _refresh_period) {
_pkt_count = 0;
_ctx_idx = !_ctx_idx;
}
pkt->KK = _ctx_pair[_ctx_idx]->_kk;
return _ctx_pair[_ctx_idx]->encrypt(pkt->packet_seq_number, buf, len);
}
BufferLikeString::Ptr Crypto::decrypt(DataPacket::Ptr pkt, const char *buf, int len) {
CryptoContext::Ptr _ctx;
if (pkt->KK == KeyMaterial::KEY_BASED_ENCRYPTION_NO_SEK) {
auto payload = std::make_shared<BufferLikeString>();
payload->assign(buf, len);
return payload;
} else if (pkt->KK == KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK) {
_ctx = _ctx_pair[0];
} else if (pkt->KK == KeyMaterial::KEY_BASED_ENCRYPTION_ODD_SEK) {
_ctx = _ctx_pair[1];
}
if (!_ctx) {
WarnL << "not has effective KeyMaterial with kk: " << pkt->KK;
return nullptr;
}
return _ctx->decrypt(pkt->packet_seq_number, buf, len);
}
} // namespace SRT

102
srt/Crypto.hpp Normal file
View File

@ -0,0 +1,102 @@
#ifndef ZLMEDIAKIT_SRT_CRYPTO_H
#define ZLMEDIAKIT_SRT_CRYPTO_H
#include <stdint.h>
#include <vector>
#include "Network/Buffer.h"
#include "Network/sockutil.h"
#include "Util/logger.h"
#include "Common.hpp"
#include "HSExt.hpp"
#include "Packet.hpp"
namespace SRT {
class CryptoContext : public std::enable_shared_from_this<CryptoContext> {
public:
using Ptr = std::shared_ptr<CryptoContext>;
CryptoContext(const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet = nullptr);
virtual ~CryptoContext() = default;
virtual void refresh();
virtual std::string generateWarppedKey();
virtual BufferLikeString::Ptr encrypt(uint32_t pkt_seq_no, const char *buf, int len) = 0;
virtual BufferLikeString::Ptr decrypt(uint32_t pkt_seq_no, const char *buf, int len) = 0;
virtual uint8_t getCipher() const = 0;
protected:
virtual void loadFromKeyMaterial(KeyMaterial::Ptr packet);
virtual bool generateKEK();
BufferLikeString::Ptr generateIv(uint32_t pkt_seq_no);
private:
public:
std::string _passparase;
uint8_t _kk = SRT::KeyMaterial::KEY_BASED_ENCRYPTION_EVEN_SEK;
BufferLikeString _kek;
const uint32_t _iter = 2048;
size_t _slen = 16;
BufferLikeString _salt;
size_t _klen = 16;
BufferLikeString _sek;
};
class AesCtrCryptoContext : public CryptoContext {
public:
using Ptr = std::shared_ptr<AesCtrCryptoContext>;
AesCtrCryptoContext(const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet = nullptr);
virtual ~AesCtrCryptoContext() = default;
uint8_t getCipher() const override {
return KeyMaterial::CIPHER_AES_CTR;
}
BufferLikeString::Ptr encrypt(uint32_t pkt_seq_no, const char *buf, int len) override;
BufferLikeString::Ptr decrypt(uint32_t pkt_seq_no, const char *buf, int len) override;
};
class Crypto : public std::enable_shared_from_this<Crypto>{
public:
using Ptr = std::shared_ptr<Crypto>;
Crypto(const std::string& passparase);
virtual ~Crypto() = default;
HSExtKeyMaterial::Ptr generateKeyMaterialExt(uint16_t extension_type);
KeyMaterialPacket::Ptr takeAwayAnnouncePacket();
bool loadFromKeyMaterial(KeyMaterial::Ptr packet);
// for encryption
std::string _passparase;
//The recommended KM Refresh Period is after 2^25 packets encrypted with the same SEK are sent.
const uint32_t _refresh_period = 1 <<25;
const uint32_t _re_announcement_period = (1 <<25) - 4000;
uint32_t _pkt_count = 0;
KeyMaterialPacket::Ptr _re_announce_pkt;
CryptoContext::Ptr _ctx_pair[2]; /* Even(0)/Odd(1) crypto contexts */
uint32_t _ctx_idx = 0;
BufferLikeString::Ptr encrypt(DataPacket::Ptr pkt, const char *buf, int len);
BufferLikeString::Ptr decrypt(DataPacket::Ptr pkt, const char *buf, int len);
private:
CryptoContext::Ptr createCtx(int cipher, const std::string& passparase, uint8_t kk, KeyMaterial::Ptr packet = nullptr);
KeyMaterialPacket::Ptr generateAnnouncePacket(CryptoContext::Ptr ctx);
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_CRYPTO_H

View File

@ -131,4 +131,162 @@ std::string HSExtStreamID::dump() {
return std::move(printer);
}
} // namespace SRT
size_t KeyMaterial::getContentSize() {
size_t variable_width = _slen + _warpped_key.size();
size_t content_size = variable_width + 16;
return content_size;
}
bool KeyMaterial::loadFromData(uint8_t *buf, size_t len) {
if (buf == NULL || len < 16) {
return false;
}
uint8_t *ptr = (uint8_t *)buf;
_km_version = (*ptr & 0x70) >> 4;
_pt = *ptr & 0x0f;
ptr += 1;
_sign = loadUint16(ptr);
ptr += 2;
_kk = *ptr & 0x03;
auto sek_num = 1;
if (_kk == KEY_BASED_ENCRYPTION_BOTH_SEK) {
sek_num = 2;
}
ptr += 1;
_keki = loadUint32(ptr);
ptr += 4;
_cipher = *ptr;
ptr += 1;
_auth = *ptr;
ptr += 1;
_se = *ptr;
ptr += 1;
//Resv2
ptr += 1;
//Resv3
ptr += 2;
_slen = *ptr *4;
ptr += 1;
_klen = *ptr *4;
ptr += 1;
size_t wrapped_key_len = 8 + sek_num * _klen;
size_t variable_width = _slen + wrapped_key_len;
if (len < variable_width + 16) {
return false;
}
_salt.assign((const char*)ptr, (size_t)_slen);
ptr += _slen;
_warpped_key.assign((const char*)ptr, (size_t)wrapped_key_len);
return true;
}
bool KeyMaterial::storeToData(uint8_t *buf, size_t len) {
auto content_size = getContentSize();
if (len < content_size) {
return false;
}
uint8_t *ptr = (uint8_t *)buf;
memset(ptr, 0, len);
*ptr = ((_km_version << 4)& 0x70) | (_pt & 0x0f);
ptr += 1;
storeUint16(ptr, _sign);
ptr += 2;
*ptr = _kk & 0x03;
ptr += 1;
storeUint32(ptr, _keki);
ptr += 4;
*ptr = _cipher;
ptr += 1;
*ptr = _auth;
ptr += 1;
*ptr = _se;
ptr += 1;
*ptr = 0; //Resv2
ptr += 1;
storeUint16(ptr, 0);//Resv3
ptr += 2;
*ptr = (uint8_t)(_slen/4);
ptr += 1;
*ptr = (uint8_t)(_klen/4);
ptr += 1;
const char *src = _salt.data();
for (size_t i = 0; i < _salt.size(); ptr++, src++, i++) {
*ptr = *src;
}
src = _warpped_key.data();
for (size_t i = 0; i < _warpped_key.size(); ptr++, src++, i++) {
*ptr = *src;
}
return true;
}
std::string KeyMaterial::dump() {
_StrPrinter printer;
printer << "kmVersion: " << _km_version
<< " pt : " << _pt
<< " sign : " << std::hex << _sign
<< " kk : " << _kk
<< " keki : " << _keki
<< " cipher : " << _cipher
<< " auth : " << _auth
<< " se : " << _se
<< " sLen : " << _slen
<< " salt : " << std::hex << _salt.data()
<< " kLen : " << _klen;
return std::move(printer);
}
bool HSExtKeyMaterial::loadFromData(uint8_t *buf, size_t len) {
if (buf == NULL || len < 4) {
return false;
}
HSExt::_data = BufferRaw::create();
HSExt::_data->assign((char *)buf, len);
HSExt::loadHeader();
assert(extension_type == SRT_CMD_KMREQ || extension_type == SRT_CMD_KMRSP);
return KeyMaterial::loadFromData(buf +4, len -4);
}
bool HSExtKeyMaterial::storeToData() {
size_t content_size = ((KeyMaterial::getContentSize() + 4) + 3) / 4 * 4;
HSExt::_data = BufferRaw::create();
HSExt::_data->setCapacity(content_size);
HSExt::_data->setSize(content_size);
extension_length = (content_size - 4) / 4;
HSExt::storeHeader();
return KeyMaterial::storeToData((uint8_t*)_data->data() + 4, content_size - 4);
}
std::string HSExtKeyMaterial::dump() {
return KeyMaterial::dump();
}
} // namespace SRT

View File

@ -125,5 +125,118 @@ public:
std::string dump() override;
std::string streamid;
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
|S| V | PT | Sign | Resv1 | KK|
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| KEKI |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Cipher | Auth | SE | Resv2 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Resv3 | SLen/4 | KLen/4 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Salt |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| |
+ Wrapped Key +
| |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 11: Key Material Message structure
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-key-material
*/
class KeyMaterial {
public:
using Ptr = std::shared_ptr<KeyMaterial>;
KeyMaterial() = default;
virtual ~KeyMaterial() = default;
bool loadFromData(uint8_t *buf, size_t len);
bool storeToData(uint8_t *buf, size_t len);
std::string dump();
protected:
size_t getContentSize();
public:
enum {
PACKET_TYPE_RESERVED = 0b0000,
PACKET_TYPE_MSMSG = 0b0001, // 1-Media Strem Message
PACKET_TYPE_KMMSG = 0b0010, // 2-Keying Material Message
PACKET_TYPE_MPEG_TS = 0b0111, // 7-MPEG-TS packet
};
enum {
KEY_BASED_ENCRYPTION_NO_SEK = 0b00,
KEY_BASED_ENCRYPTION_EVEN_SEK = 0b01,
KEY_BASED_ENCRYPTION_ODD_SEK = 0b10,
KEY_BASED_ENCRYPTION_BOTH_SEK = 0b11,
};
enum {
CIPHER_NONE = 0x00,
CIPHER_AES_ECB = 0x01, //reserved, not support
CIPHER_AES_CTR = 0x02,
CIPHER_AES_CBC = 0x03, //reserved, not support
CIPHER_AES_GCM = 0x04
};
enum {
AUTHENTICATION_NONE = 0x00,
AUTH_AES_GCM = 0x01,
};
enum {
STREAM_ENCAPSUALTION_UNSPECIFIED = 0x00,
STREAM_ENCAPSUALTION_MPEG_TS_UDP = 0x01,
STREAM_ENCAPSUALTION_MPEG_TS_SRT = 0x02,
};
uint8_t _km_version = 0b001;
uint8_t _pt = PACKET_TYPE_KMMSG;
uint16_t _sign = 0x2029;
uint8_t _kk = KEY_BASED_ENCRYPTION_EVEN_SEK;
uint32_t _keki = 0;
uint8_t _cipher = CIPHER_AES_CTR;
uint8_t _auth = AUTHENTICATION_NONE;
uint8_t _se = STREAM_ENCAPSUALTION_MPEG_TS_SRT;
uint16_t _slen = 16;
uint16_t _klen = 16;
BufferLikeString _salt;
BufferLikeString _warpped_key;
};
class HSExtKeyMaterial : public HSExt, public KeyMaterial {
public:
using Ptr = std::shared_ptr<HSExtKeyMaterial>;
HSExtKeyMaterial() = default;
virtual ~HSExtKeyMaterial() = default;
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
std::string dump() override;
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| KM State |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
Figure 7: KM Response Error
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-key-material-extension-mess
*/
class HSExtKMResponseError : public HSExt {
public:
using Ptr = std::shared_ptr<HSExtKMResponseError>;
HSExtKMResponseError() = default;
~HSExtKMResponseError() = default;
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
std::string dump() override;
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_HS_EXT_H
#endif // ZLMEDIAKIT_SRT_HS_EXT_H

View File

@ -55,6 +55,13 @@ bool DataPacket::loadFromData(uint8_t *buf, size_t len) {
return true;
}
bool DataPacket::reloadPayload(uint8_t *buf, size_t len) {
_data->setCapacity(len + HEADER_SIZE);
_data->setSize(len + HEADER_SIZE);
memcpy(_data->data() + HEADER_SIZE, buf, len);
return true;
}
bool DataPacket::storeToHeader() {
if (!_data || _data->size() < HEADER_SIZE) {
WarnL << "data size less " << HEADER_SIZE;
@ -162,6 +169,12 @@ uint16_t ControlPacket::getControlType(uint8_t *buf, size_t len) {
return control_type;
}
uint16_t ControlPacket::getSubType(uint8_t *buf, size_t len) {
uint8_t *ptr = buf;
uint16_t subtype = loadUint16(ptr + 2);
return subtype;
}
bool ControlPacket::loadHeader() {
uint8_t *ptr = (uint8_t *)_data->data();
f = ptr[0] >> 7;
@ -225,6 +238,20 @@ size_t ControlPacket::size() const {
uint32_t ControlPacket::getSocketID(uint8_t *buf, size_t len) {
return loadUint32(buf + 12);
}
#define XX(name, value, str) {str, name},
std::map<std::string, SRT_REJECT_REASON> reject_map = {REJ_MAP(XX)};
#undef XX
std::string getRejectReason(SRT_REJECT_REASON code) {
switch (code) {
#define XX(name, value, str) case name : return str;
REJ_MAP(XX)
#undef XX
default : return "invalid";
}
}
std::string HandshakePacket::dump(){
_StrPrinter printer;
printer <<"flag:"<< (int)f<<"\r\n";
@ -324,6 +351,9 @@ bool HandshakePacket::loadExtMessage(uint8_t *buf, size_t len) {
case HSExt::SRT_CMD_HSREQ:
case HSExt::SRT_CMD_HSRSP: ext = std::make_shared<HSExtMessage>(); break;
case HSExt::SRT_CMD_SID: ext = std::make_shared<HSExtStreamID>(); break;
case HSExt::SRT_CMD_KMREQ:
case HSExt::SRT_CMD_KMRSP:
ext = std::make_shared<HSExtKeyMaterial>(); break;
default: WarnL << "not support ext " << type; break;
}
if (ext) {
@ -451,6 +481,23 @@ void HandshakePacket::assignPeerIP(struct sockaddr_storage *addr) {
}
}
void HandshakePacket::assignPeerIPBE(struct sockaddr_storage *addr) {
memset(peer_ip_addr, 0, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]));
if (addr->ss_family == AF_INET) {
struct sockaddr_in *ipv4 = (struct sockaddr_in *)addr;
storeUint32(peer_ip_addr, ipv4->sin_addr.s_addr);
} else if (addr->ss_family == AF_INET6) {
if (IN6_IS_ADDR_V4MAPPED(&((struct sockaddr_in6 *)addr)->sin6_addr)) {
struct in_addr addr4;
memcpy(&addr4, 12 + (char *)&(((struct sockaddr_in6 *)addr)->sin6_addr), 4);
storeUint32(peer_ip_addr, addr4.s_addr);
} else {
const sockaddr_in6 *ipv6 = (struct sockaddr_in6 *)addr;
memcpy(peer_ip_addr, ipv6->sin6_addr.s6_addr, sizeof(peer_ip_addr) * sizeof(peer_ip_addr[0]));
}
}
}
uint32_t HandshakePacket::generateSynCookie(
struct sockaddr_storage *addr, TimePoint ts, uint32_t current_cookie, int correction) {
static std::atomic<uint32_t> distractor { 0 };
@ -619,4 +666,4 @@ bool MsgDropReqPacket::storeToData() {
ptr += 4;
return true;
}
} // namespace SRT
} // namespace SRT

View File

@ -57,6 +57,7 @@ public:
static bool isDataPacket(uint8_t *buf, size_t len);
static uint32_t getSocketID(uint8_t *buf, size_t len);
bool loadFromData(uint8_t *buf, size_t len);
bool reloadPayload(uint8_t *buf, size_t len);
bool storeToData(uint8_t *buf, size_t len);
bool storeToHeader();
@ -105,6 +106,7 @@ public:
static const size_t HEADER_SIZE = 16;
static bool isControlPacket(uint8_t *buf, size_t len);
static uint16_t getControlType(uint8_t *buf, size_t len);
static uint16_t getSubType(uint8_t *buf, size_t len);
static uint32_t getSocketID(uint8_t *buf, size_t len);
ControlPacket() = default;
@ -180,6 +182,37 @@ protected:
Figure 5: Handshake packet structure
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-handshake
*/
// REJ code,from libsrt
#define REJ_MAP(XX) \
XX(SRT_REJ_UNKNOWN, 1000, "Unknown or erroneous") \
XX(SRT_REJ_SYSTEM, 1001, "Error in system calls") \
XX(SRT_REJ_PEER, 1002, "Peer rejected connection") \
XX(SRT_REJ_RESOURCE, 1003, "Resource allocation failure") \
XX(SRT_REJ_ROGUE, 1004, "Rogue peer or incorrect parameters") \
XX(SRT_REJ_BACKLOG, 1005, "Listener's backlog exceeded") \
XX(SRT_REJ_IPE, 1006, "Internal Program Error") \
XX(SRT_REJ_CLOSE, 1007, "Socket is being closed") \
XX(SRT_REJ_VERSION, 1008, "Peer version too old") \
XX(SRT_REJ_RDVCOOKIE, 1009, "Rendezvous-mode cookie collision") \
XX(SRT_REJ_BADSECRET, 1010, "Incorrect passphrase") \
XX(SRT_REJ_UNSECURE, 1011, "Password required or unexpected") \
XX(SRT_REJ_MESSAGEAPI, 1012, "MessageAPI/StreamAPI collision") \
XX(SRT_REJ_CONGESTION, 1013, "Congestion controller type collision") \
XX(SRT_REJ_FILTER, 1014, "Packet Filter settings error") \
XX(SRT_REJ_GROUP, 1015, "Group settings collision") \
XX(SRT_REJ_TIMEOUT, 1016, "Connection timeout") \
XX(SRT_REJ_CRYPTO, 1017, "Crypto mode")
typedef enum {
#define XX(name, value, str) name = value,
REJ_MAP(XX)
#undef XX
SRT_REJ_E_SIZE
} SRT_REJECT_REASON;
std::string getRejectReason(SRT_REJECT_REASON code);
class HandshakePacket : public ControlPacket {
public:
using Ptr = std::shared_ptr<HandshakePacket>;
@ -205,6 +238,10 @@ public:
generateSynCookie(struct sockaddr_storage *addr, TimePoint ts, uint32_t current_cookie = 0, int correction = 0);
std::string dump();
void assignPeerIP(struct sockaddr_storage *addr);
void assignPeerIPBE(struct sockaddr_storage *addr);
bool isReject() {
return (handshake_type >= SRT_REJ_UNKNOWN && handshake_type < SRT_REJ_E_SIZE);
}
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override;
bool storeToData() override;
@ -367,6 +404,56 @@ public:
}
};
/*
0 1 2 3
0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1
+-+-+-+-+-+-+-+-+-+-+-+-+- SRT Header +-+-+-+-+-+-+-+-+-+-+-+-+-+
|1| Control Type = 0x7FFF | Subtype = 3/4 |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Type-specific Information |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Timestamp |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
| Destination Socket ID |
+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+
the Control Type field of the SRT packet header is set to User-Defined Type (see Table 1),
the Subtype field of the header is set to SRT_CMD_KMREQ for key-refresh request
and SRT_CMD_KMRSP for key-refresh response (Table 5). The KM Refresh mechanism is described in Section 6.1.6.
https://haivision.github.io/srt-rfc/draft-sharabayko-srt.html#name-key-material
*/
class KeyMaterialPacket : public ControlPacket, public KeyMaterial {
public:
using Ptr = std::shared_ptr<KeyMaterialPacket>;
KeyMaterialPacket() = default;
~KeyMaterialPacket() = default;
///////ControlPacket override///////
bool loadFromData(uint8_t *buf, size_t len) override {
if (len < HEADER_SIZE) {
WarnL << "data size" << len << " less " << HEADER_SIZE;
return false;
}
_data = BufferRaw::create();
_data->assign((char *)buf, len);
loadHeader();
assert(sub_type == HSExt::SRT_CMD_KMREQ || sub_type == HSExt::SRT_CMD_KMRSP);
return KeyMaterial::loadFromData(buf + HEADER_SIZE, len - HEADER_SIZE);
}
bool storeToData() override {
size_t content_size = ((KeyMaterial::getContentSize() + HEADER_SIZE) + 3) / 4 * 4;
control_type = ControlPacket::USERDEFINEDTYPE;
/* sub_type = HSExt::SRT_CMD_KMREQ; */
/* sub_type = HSExt::SRT_CMD_KMRSP; */
_data = BufferRaw::create();
_data->setCapacity(content_size);
_data->setSize(content_size);
storeToHeader();
return KeyMaterial::storeToData((uint8_t*)_data->data() + HEADER_SIZE, content_size - HEADER_SIZE);
}
};
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_PACKET_H
#endif // ZLMEDIAKIT_SRT_PACKET_H

View File

@ -18,12 +18,14 @@ const std::string kTimeOutSec = SRT_FIELD "timeoutSec";
const std::string kPort = SRT_FIELD "port";
const std::string kLatencyMul = SRT_FIELD "latencyMul";
const std::string kPktBufSize = SRT_FIELD "pktBufSize";
const std::string kPassPhrase = SRT_FIELD "passPhrase";
static onceToken token([]() {
mINI::Instance()[kTimeOutSec] = 5;
mINI::Instance()[kPort] = 9000;
mINI::Instance()[kLatencyMul] = 4;
mINI::Instance()[kPktBufSize] = 8192;
mINI::Instance()[kPassPhrase] = "";
});
static std::atomic<uint32_t> s_srt_socket_id_generate { 125 };
@ -228,6 +230,8 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
// first
HSExtMessage::Ptr req;
HSExtStreamID::Ptr sid;
HSExtKeyMaterial::Ptr keyMaterial;
uint32_t srt_flag = 0xbf;
uint16_t delay = DurationCountMicroseconds(_now - _induction_ts) * getLatencyMul() / 1000;
if (delay <= 120) {
@ -241,6 +245,9 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
if (!sid) {
sid = std::dynamic_pointer_cast<HSExtStreamID>(ext);
}
if (!keyMaterial) {
keyMaterial = std::dynamic_pointer_cast<HSExtKeyMaterial>(ext);
}
}
if (sid) {
_stream_id = sid->streamid;
@ -252,6 +259,22 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
srt_flag = req->srt_flag;
delay = delay <= req->recv_tsbpd_delay ? req->recv_tsbpd_delay : delay;
}
if (!keyMaterial && getPassphrase().empty()) {
//nop
} else if (keyMaterial && !getPassphrase().empty()) {
_crypto = std::make_shared<SRT::Crypto>(getPassphrase());
if (!_crypto->loadFromKeyMaterial(keyMaterial)) {
sendRejectPacket(SRT_REJ_BADSECRET, addr);
onShutdown(SockException(Err_other, StrPrinter << "handshake fail, reject resaon: " << SRT::getRejectReason(SRT_REJ_BADSECRET)));
return;
}
} else {
sendRejectPacket(SRT_REJ_UNSECURE, addr);
onShutdown(SockException(Err_other, StrPrinter << "handshake fail, reject resaon: " << SRT::getRejectReason(SRT_REJ_UNSECURE)));
return;
}
TraceL << getIdentifier() << " CONCLUSION Phase from"<<SockUtil::inet_ntoa((struct sockaddr *)addr) << ":" << SockUtil::inet_port((struct sockaddr *)addr);;
HandshakePacket::Ptr res = std::make_shared<HandshakePacket>();
res->dst_socket_id = _peer_socket_id;
@ -262,6 +285,12 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
res->version = 5;
res->encryption_field = HandshakePacket::NO_ENCRYPTION;
res->extension_field = HandshakePacket::HS_EXT_FILED_HSREQ;
if (_crypto) {
//The default value is 0 (no encryption advertised).
//If neither peer advertises encryption, AES-128 is selected by default
/* req->encryption_field = SRT::HandshakePacket::AES_128; */
res->extension_field |= HandshakePacket::HS_EXT_FILED_KMREQ;
}
res->handshake_type = HandshakePacket::HS_TYPE_CONCLUSION;
res->srt_socket_id = _socket_id;
res->syn_cookie = 0;
@ -272,6 +301,10 @@ void SrtTransport::handleHandshakeConclusion(HandshakePacket &pkt, struct sockad
ext->srt_flag = srt_flag;
ext->recv_tsbpd_delay = ext->send_tsbpd_delay = delay;
res->ext_list.push_back(std::move(ext));
if (keyMaterial) {
keyMaterial->extension_type = HSExt::SRT_CMD_KMRSP;
res->ext_list.push_back(std::move(keyMaterial));
}
res->storeToData();
_handleshake_res = res;
unregisterSelfHandshake();
@ -366,6 +399,42 @@ void SrtTransport::sendMsgDropReq(uint32_t first, uint32_t last) {
sendControlPacket(pkt, true);
}
void SrtTransport::tryAnnounceKeyMaterial() {
//TraceL;
if (!_crypto) {
return;
}
auto pkt = _crypto->takeAwayAnnouncePacket();
if (!pkt) {
return;
}
auto now = SteadyClock::now();
pkt->dst_socket_id = _peer_socket_id;
pkt->timestamp = SRT::DurationCountMicroseconds(now - _start_timestamp);
pkt->storeToData();
_announce_req = pkt;
sendControlPacket(pkt, true);
std::weak_ptr<SrtTransport> weak_self = std::static_pointer_cast<SrtTransport>(shared_from_this());
_announce_timer = std::make_shared<Timer>(0.2, [weak_self]()->bool{
auto strong_self = weak_self.lock();
if (!strong_self) {
return false;
}
if (!strong_self->_announce_req) {
return false;
}
strong_self->sendControlPacket(strong_self->_announce_req, true);
return true;
}, getPoller());
return;
}
void SrtTransport::handleNAK(uint8_t *buf, int len, struct sockaddr_storage *addr) {
// TraceL;
NAKPacket pkt;
@ -433,6 +502,8 @@ void SrtTransport::handleDropReq(uint8_t *buf, int len, struct sockaddr_storage
*/
}
void SrtTransport::checkAndSendAckNak(){
//SRT Periodic NAK reports are sent with a period of (RTT + 4 * RTTVar) / 2 (so called NAKInterval),
//with a 20 milliseconds floor
auto nak_interval = (_rtt + _rtt_variance * 4) / 2;
if (nak_interval <= 20 * 1000) {
nak_interval = 20 * 1000;
@ -468,7 +539,52 @@ void SrtTransport::checkAndSendAckNak(){
_light_ack_pkt_count++;
}
void SrtTransport::handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr) {
TraceL;
/* TraceL; */
using srt_userd_defined_handler = void (SrtTransport::*)(uint8_t * buf, int len, struct sockaddr_storage *addr);
static std::unordered_map<uint16_t /*sub_type*/, srt_userd_defined_handler> s_userd_defined_functions;
static onceToken token([]() {
s_userd_defined_functions.emplace(SRT::HSExt::SRT_CMD_KMREQ, &SrtTransport::handleKeyMaterialReqPacket);
s_userd_defined_functions.emplace(SRT::HSExt::SRT_CMD_KMRSP, &SrtTransport::handleKeyMaterialRspPacket);
});
uint16_t subtype = ControlPacket::getSubType(buf, len);
auto it = s_userd_defined_functions.find(subtype);
if (it == s_userd_defined_functions.end()) {
WarnL << " not support subtype in user defined msg ignore: " << subtype;
return;
} else {
(this->*(it->second))(buf, len, addr);
}
return;
}
void SrtTransport::handleKeyMaterialReqPacket(uint8_t *buf, int len, struct sockaddr_storage *addr) {
/* TraceL; */
if (!_crypto) {
WarnL << " not enable crypto, ignore";
return;
}
KeyMaterialPacket::Ptr pkt = std::make_shared<KeyMaterialPacket>();
pkt->loadFromData(buf, len);
_crypto->loadFromKeyMaterial(pkt);
//rsp
pkt->sub_type = SRT::HSExt::SRT_CMD_KMRSP;
pkt->dst_socket_id = _peer_socket_id;
pkt->timestamp = DurationCountMicroseconds(_now - _start_timestamp);
pkt->storeToData();
sendControlPacket(pkt, true);
return;
}
void SrtTransport::handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockaddr_storage *addr) {
/* TraceL; */
_announce_req = nullptr;
return;
}
void SrtTransport::handleACKACK(uint8_t *buf, int len, struct sockaddr_storage *addr) {
@ -603,6 +719,25 @@ void SrtTransport::sendNAKPacket(std::list<PacketQueue::LostPair> &lost_list) {
// TraceL<<"send NAK "<<pkt->dump();
}
void SrtTransport::sendRejectPacket(SRT_REJECT_REASON reason, struct sockaddr_storage *addr) {
HandshakePacket::Ptr res = std::make_shared<HandshakePacket>();
res->dst_socket_id = _peer_socket_id;
res->timestamp = DurationCountMicroseconds(_now - _start_timestamp);
res->mtu = _mtu;
res->max_flow_window_size = _max_window_size;
res->initial_packet_sequence_number = _init_seq_number;
res->version = 5;
res->encryption_field = HandshakePacket::NO_ENCRYPTION;
res->extension_field = HandshakePacket::HS_EXT_FILED_HSREQ;
res->handshake_type = reason;
res->srt_socket_id = _socket_id;
res->syn_cookie = 0;
res->assignPeerIP(addr);
res->storeToData();
sendControlPacket(res, true);
return;
}
void SrtTransport::sendShutDown() {
ShutDownPacket::Ptr pkt = std::make_shared<ShutDownPacket>();
pkt->dst_socket_id = _peer_socket_id;
@ -615,6 +750,16 @@ void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_stora
DataPacket::Ptr pkt = std::make_shared<DataPacket>();
pkt->loadFromData(buf, len);
if (_crypto) {
auto payload = _crypto->decrypt(pkt, pkt->payloadData(), pkt->payloadSize());
if (!payload) {
WarnL << "decrypt pkt->packet_seq_number: " << pkt->packet_seq_number << ", timestamp: " << "pkt->timestamp " << " fail";
return;
}
pkt->reloadPayload((uint8_t*)payload->data(), payload->size());
}
_estimated_link_capacity_context->inputPacket(_now,pkt);
std::list<DataPacket::Ptr> list;
@ -684,9 +829,26 @@ void SrtTransport::handleDataPacket(uint8_t *buf, int len, struct sockaddr_stora
}
void SrtTransport::sendDataPacket(DataPacket::Ptr pkt, char *buf, int len, bool flush) {
pkt->storeToData((uint8_t *)buf, len);
auto data = buf;
auto size = len;
BufferLikeString::Ptr payload;
if (_crypto) {
payload = _crypto->encrypt(pkt, const_cast<char*>(buf), len);
if (!payload) {
WarnL << "encrypt pkt->packet_seq_number: " << pkt->packet_seq_number << ", timestamp: " << "pkt->timestamp " << " fail";
return;
}
data = payload->data();
size = payload->size();
tryAnnounceKeyMaterial();
}
pkt->storeToData((uint8_t *)data, size);
sendPacket(pkt, flush);
_send_buf->inputPacket(pkt);
return;
}
void SrtTransport::sendControlPacket(ControlPacket::Ptr pkt, bool flush) {
@ -836,4 +998,4 @@ SrtTransport::Ptr SrtTransportManager::getHandshakeItem(const uint32_t key) {
return it->second.lock();
}
} // namespace SRT
} // namespace SRT

View File

@ -13,6 +13,7 @@
#include "Common.hpp"
#include "NackContext.hpp"
#include "Packet.hpp"
#include "Crypto.hpp"
#include "PacketQueue.hpp"
#include "PacketSendQueue.hpp"
#include "Statistic.hpp"
@ -24,6 +25,7 @@ extern const std::string kPort;
extern const std::string kTimeOutSec;
extern const std::string kLatencyMul;
extern const std::string kPktBufSize;
extern const std::string kPassPhrase;
class SrtTransport : public std::enable_shared_from_this<SrtTransport> {
public:
@ -60,6 +62,7 @@ protected:
virtual int getLatencyMul() { return 4; };
virtual int getPktBufSize() { return 8192; };
virtual float getTimeOutSec(){return 5.0;};
virtual std::string getPassphrase() {return "";};
private:
void registerSelf();
@ -79,15 +82,19 @@ private:
void handleShutDown(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleDropReq(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleUserDefinedType(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleKeyMaterialReqPacket(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleKeyMaterialRspPacket(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handlePeerError(uint8_t *buf, int len, struct sockaddr_storage *addr);
void handleDataPacket(uint8_t *buf, int len, struct sockaddr_storage *addr);
void sendNAKPacket(std::list<PacketQueue::LostPair> &lost_list);
void sendACKPacket();
void sendRejectPacket(SRT_REJECT_REASON reason, struct sockaddr_storage *addr);
void sendLightACKPacket();
void sendKeepLivePacket();
void sendShutDown();
void sendMsgDropReq(uint32_t first, uint32_t last);
void tryAnnounceKeyMaterial();
size_t getPayloadSize() const;
@ -159,6 +166,11 @@ private:
Ticker _alive_ticker;
bool _is_handleshake_finished = false;
// for encryption
Crypto::Ptr _crypto;
Timer::Ptr _announce_timer;
KeyMaterialPacket::Ptr _announce_req;
};
class SrtTransportManager {
@ -185,4 +197,4 @@ private:
} // namespace SRT
#endif // ZLMEDIAKIT_SRT_TRANSPORT_H
#endif // ZLMEDIAKIT_SRT_TRANSPORT_H

View File

@ -370,6 +370,11 @@ float SrtTransportImp::getTimeOutSec() {
return timeOutSec;
}
std::string SrtTransportImp::getPassphrase() {
GET_CONFIG(string, passphrase, kPassPhrase);
return passphrase;
}
int SrtTransportImp::getPktBufSize() {
// kPktBufSize
GET_CONFIG(int, pktBufSize, kPktBufSize);
@ -380,4 +385,4 @@ int SrtTransportImp::getPktBufSize() {
return pktBufSize;
}
} // namespace SRT
} // namespace SRT

View File

@ -38,6 +38,7 @@ protected:
int getLatencyMul() override;
int getPktBufSize() override;
float getTimeOutSec() override;
std::string getPassphrase() override;
void onSRTData(DataPacket::Ptr pkt) override;
void onShutdown(const SockException &ex) override;
void onHandShakeFinished(std::string &streamid, struct sockaddr_storage *addr) override;